@@ -943,9 +943,15 @@ struct RenormTempStorage {
943
943
float max_val;
944
944
float min_val;
945
945
union {
946
- float value;
947
- int count;
948
- ValueCount<float > pair;
946
+ struct {
947
+ float values[2 ];
948
+ };
949
+ struct {
950
+ int counts[2 ];
951
+ };
952
+ struct {
953
+ ValueCount<float > pairs[2 ];
954
+ };
949
955
} block_aggregate;
950
956
};
951
957
};
@@ -964,7 +970,6 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
964
970
reinterpret_cast <RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
965
971
temp_storage.max_val = 0 ;
966
972
vec_t <float , VEC_SIZE> probs_vec;
967
- float probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
968
973
969
974
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
970
975
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(probs, row_idx, d,
@@ -981,8 +986,10 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
981
986
// stopping condition
982
987
// - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
983
988
do {
984
- float threadlocal_sum = 0 ;
985
- double mid = (low + high) / 2 ;
989
+ double pivot_0 = (high + 2 * low) / 3 ;
990
+ double pivot_1 = (2 * high + low) / 3 ;
991
+
992
+ float aggregate_gt_pivot_0 = 0 , aggregate_gt_pivot_1 = 0 ;
986
993
min_gt_low = high;
987
994
max_le_high = low;
988
995
#pragma unroll 2
@@ -991,19 +998,29 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
991
998
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
992
999
probs_vec.cast_load (probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
993
1000
}
1001
+
1002
+ float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
994
1003
#pragma unroll
995
1004
for (uint32_t j = 0 ; j < VEC_SIZE; ++j) {
996
- probs_greater_than_pivot[j] = (probs_vec[j] > mid) ? probs_vec[j] : 0 ;
1005
+ probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0 ;
1006
+ probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0 ;
1007
+
997
1008
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
998
1009
min_gt_low = min (min_gt_low, probs_vec[j]);
999
1010
}
1000
1011
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
1001
1012
max_le_high = max (max_le_high, probs_vec[j]);
1002
1013
}
1003
1014
}
1004
- threadlocal_sum +=
1015
+
1016
+ aggregate_gt_pivot_0 +=
1017
+ BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1018
+ .Sum <VEC_SIZE>(probs_gt_pivot_0);
1019
+ __syncthreads ();
1020
+
1021
+ aggregate_gt_pivot_1 +=
1005
1022
BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1006
- .Sum <VEC_SIZE>(probs_greater_than_pivot );
1023
+ .Sum <VEC_SIZE>(probs_gt_pivot_1 );
1007
1024
__syncthreads ();
1008
1025
}
1009
1026
min_gt_low = BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
@@ -1013,19 +1030,26 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
1013
1030
BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1014
1031
.Reduce (max_le_high, cub::Max ());
1015
1032
if (tx == 0 ) {
1016
- temp_storage.block_aggregate .value = threadlocal_sum;
1033
+ temp_storage.block_aggregate .values [0 ] = aggregate_gt_pivot_0;
1034
+ temp_storage.block_aggregate .values [1 ] = aggregate_gt_pivot_1;
1017
1035
temp_storage.min_val = min_gt_low;
1018
1036
temp_storage.max_val = max_le_high;
1019
1037
}
1020
1038
__syncthreads ();
1021
- threadlocal_sum = temp_storage.block_aggregate .value ;
1039
+ aggregate_gt_pivot_0 = temp_storage.block_aggregate .values [0 ];
1040
+ aggregate_gt_pivot_1 = temp_storage.block_aggregate .values [1 ];
1022
1041
min_gt_low = temp_storage.min_val ;
1023
1042
max_le_high = temp_storage.max_val ;
1024
- if (threadlocal_sum >= p) {
1025
- low = mid;
1026
- sum_low = threadlocal_sum;
1043
+
1044
+ if (aggregate_gt_pivot_1 >= p) {
1045
+ low = pivot_1;
1046
+ sum_low = aggregate_gt_pivot_1;
1047
+ } else if (aggregate_gt_pivot_0 >= p) {
1048
+ low = pivot_0;
1049
+ high = min (pivot_1, max_le_high);
1050
+ sum_low = aggregate_gt_pivot_0;
1027
1051
} else {
1028
- high = min (mid , max_le_high);
1052
+ high = min (pivot_0 , max_le_high);
1029
1053
}
1030
1054
} while (min_gt_low != max_le_high);
1031
1055
@@ -1079,9 +1103,10 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
1079
1103
// stopping condition: min_gt_low == max_le_high
1080
1104
// - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
1081
1105
do {
1082
- int threadlocal_count_sum = 0 ;
1083
- int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0
1084
- double mid = (low + high) / 2 ;
1106
+ double pivot_0 = (high + 2 * low) / 3 ;
1107
+ double pivot_1 = (2 * high + low) / 3 ;
1108
+
1109
+ int aggregate_gt_pivot_0 = 0 , aggregate_gt_pivot_1 = 0 ;
1085
1110
min_gt_low = high;
1086
1111
max_le_high = low;
1087
1112
#pragma unroll 2
@@ -1090,20 +1115,30 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
1090
1115
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
1091
1116
logits_vec.cast_load (logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
1092
1117
}
1118
+ int probs_gt_pivot_0_count[VEC_SIZE], probs_gt_pivot_1_count[VEC_SIZE];
1093
1119
#pragma unroll
1094
1120
for (uint32_t j = 0 ; j < VEC_SIZE; ++j) {
1095
- probs_greater_than_pivot_count[j] =
1096
- logits_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
1121
+ probs_gt_pivot_0_count[j] =
1122
+ logits_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
1123
+ probs_gt_pivot_1_count[j] =
1124
+ logits_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
1125
+
1097
1126
if (logits_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
1098
1127
min_gt_low = min (min_gt_low, logits_vec[j]);
1099
1128
}
1100
1129
if (logits_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
1101
1130
max_le_high = max (max_le_high, logits_vec[j]);
1102
1131
}
1103
1132
}
1104
- threadlocal_count_sum +=
1133
+
1134
+ aggregate_gt_pivot_0 +=
1135
+ BlockReduce<int , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce_int )
1136
+ .Sum <VEC_SIZE>(probs_gt_pivot_0_count);
1137
+ __syncthreads ();
1138
+
1139
+ aggregate_gt_pivot_1 +=
1105
1140
BlockReduce<int , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce_int )
1106
- .Sum <VEC_SIZE>(probs_greater_than_pivot_count );
1141
+ .Sum <VEC_SIZE>(probs_gt_pivot_1_count );
1107
1142
__syncthreads ();
1108
1143
}
1109
1144
min_gt_low =
@@ -1114,18 +1149,24 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
1114
1149
BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1115
1150
.Reduce (max_le_high, cub::Max ());
1116
1151
if (tx == 0 ) {
1117
- temp_storage.block_aggregate .count = threadlocal_count_sum;
1152
+ temp_storage.block_aggregate .counts [0 ] = aggregate_gt_pivot_0;
1153
+ temp_storage.block_aggregate .counts [1 ] = aggregate_gt_pivot_1;
1118
1154
temp_storage.min_val = min_gt_low;
1119
1155
temp_storage.max_val = max_le_high;
1120
1156
}
1121
1157
__syncthreads ();
1122
- threadlocal_count_sum = temp_storage.block_aggregate .count ;
1158
+ aggregate_gt_pivot_0 = temp_storage.block_aggregate .counts [0 ];
1159
+ aggregate_gt_pivot_1 = temp_storage.block_aggregate .counts [1 ];
1123
1160
min_gt_low = temp_storage.min_val ;
1124
1161
max_le_high = temp_storage.max_val ;
1125
- if (threadlocal_count_sum >= k) {
1126
- low = mid;
1162
+
1163
+ if (aggregate_gt_pivot_1 >= k) {
1164
+ low = pivot_1;
1165
+ } else if (aggregate_gt_pivot_0 >= k) {
1166
+ low = pivot_0;
1167
+ high = min (pivot_1, max_le_high);
1127
1168
} else {
1128
- high = min (mid , max_le_high);
1169
+ high = min (pivot_0 , max_le_high);
1129
1170
}
1130
1171
} while (min_gt_low != max_le_high);
1131
1172
pivot = low;
@@ -1164,7 +1205,6 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
1164
1205
auto & temp_storage =
1165
1206
reinterpret_cast <RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
1166
1207
temp_storage.max_val = 0 ;
1167
- float probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
1168
1208
1169
1209
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
1170
1210
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
@@ -1181,9 +1221,10 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
1181
1221
// stopping condition: min_gt_low == max_le_high
1182
1222
// - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
1183
1223
do {
1184
- ValueCount<float > threadlocal_sum{0 , 0 };
1185
- ValueCount<float > probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0
1186
- double mid = (low + high) / 2 ;
1224
+ double pivot_0 = (high + 2 * low) / 3 ;
1225
+ double pivot_1 = (2 * high + low) / 3 ;
1226
+
1227
+ ValueCount<float > aggregate_gt_pivot_0{0 , 0 }, aggregate_gt_pivot_1{0 , 0 };
1187
1228
min_gt_low = high;
1188
1229
max_le_high = low;
1189
1230
#pragma unroll 2
@@ -1192,21 +1233,32 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
1192
1233
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
1193
1234
probs_vec.cast_load (probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
1194
1235
}
1236
+ ValueCount<float > probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
1195
1237
#pragma unroll
1196
1238
for (uint32_t j = 0 ; j < VEC_SIZE; ++j) {
1197
- probs_greater_than_pivot_pair[j] = {
1198
- (probs_vec[j] > mid) ? probs_vec[j] : 0 ,
1199
- (probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
1239
+ probs_gt_pivot_0_pair[j] = {
1240
+ (probs_vec[j] > pivot_0) ? probs_vec[j] : 0 ,
1241
+ (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
1242
+ probs_gt_pivot_1_pair[j] = {
1243
+ (probs_vec[j] > pivot_1) ? probs_vec[j] : 0 ,
1244
+ (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
1245
+
1200
1246
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
1201
1247
min_gt_low = min (min_gt_low, probs_vec[j]);
1202
1248
}
1203
1249
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
1204
1250
max_le_high = max (max_le_high, probs_vec[j]);
1205
1251
}
1206
1252
}
1207
- threadlocal_sum += BlockReduce<ValueCount<float >, BLOCK_THREADS, REDUCE_ALGORITHM>(
1208
- temp_storage.block_prim .reduce_value_count )
1209
- .Sum <VEC_SIZE>(probs_greater_than_pivot_pair);
1253
+
1254
+ aggregate_gt_pivot_0 += BlockReduce<ValueCount<float >, BLOCK_THREADS, REDUCE_ALGORITHM>(
1255
+ temp_storage.block_prim .reduce_value_count )
1256
+ .Sum <VEC_SIZE>(probs_gt_pivot_0_pair);
1257
+ __syncthreads ();
1258
+
1259
+ aggregate_gt_pivot_1 += BlockReduce<ValueCount<float >, BLOCK_THREADS, REDUCE_ALGORITHM>(
1260
+ temp_storage.block_prim .reduce_value_count )
1261
+ .Sum <VEC_SIZE>(probs_gt_pivot_1_pair);
1210
1262
__syncthreads ();
1211
1263
}
1212
1264
min_gt_low =
@@ -1217,19 +1269,26 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
1217
1269
BlockReduce<float , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1218
1270
.Reduce (max_le_high, cub::Max ());
1219
1271
if (tx == 0 ) {
1220
- temp_storage.block_aggregate .pair = threadlocal_sum;
1272
+ temp_storage.block_aggregate .pairs [0 ] = aggregate_gt_pivot_0;
1273
+ temp_storage.block_aggregate .pairs [1 ] = aggregate_gt_pivot_1;
1221
1274
temp_storage.min_val = min_gt_low;
1222
1275
temp_storage.max_val = max_le_high;
1223
1276
}
1224
1277
__syncthreads ();
1225
- threadlocal_sum = temp_storage.block_aggregate .pair ;
1278
+ aggregate_gt_pivot_0 = temp_storage.block_aggregate .pairs [0 ];
1279
+ aggregate_gt_pivot_1 = temp_storage.block_aggregate .pairs [1 ];
1226
1280
min_gt_low = temp_storage.min_val ;
1227
1281
max_le_high = temp_storage.max_val ;
1228
- if (threadlocal_sum.count >= k) {
1229
- low = mid;
1230
- sum_low = float (threadlocal_sum.value );
1282
+
1283
+ if (aggregate_gt_pivot_1.count >= k) {
1284
+ low = pivot_1;
1285
+ sum_low = float (aggregate_gt_pivot_1.value );
1286
+ } else if (aggregate_gt_pivot_0.count >= k) {
1287
+ low = pivot_0;
1288
+ high = min (pivot_1, max_le_high);
1289
+ sum_low = float (aggregate_gt_pivot_0.value );
1231
1290
} else {
1232
- high = min (mid , max_le_high);
1291
+ high = min (pivot_0 , max_le_high);
1233
1292
}
1234
1293
} while (min_gt_low != max_le_high);
1235
1294
0 commit comments