Skip to content

Commit e19cb7b

Browse files
xslingcn and yzh119 authored
perf: dual pivot top-p/top-k renorm (#974)
Uses two pivots per search iteration instead of one (placed at one third and two thirds of the current [low, high] interval) and a fixed block size of 1024 threads (BLOCK_THREADS) in the sampling renorm kernels. Benchmarks and comparisons can be found [here](https://docs.google.com/spreadsheets/d/15Qxh7fRxO2HiYaSxpyZs7sWQKoJ-8gscQFTG3rW_hZI/edit?usp=sharing). --------- Co-authored-by: Zihao Ye <[email protected]>
1 parent 588c2fb commit e19cb7b

File tree

1 file changed

+102
-43
lines changed

1 file changed

+102
-43
lines changed

include/flashinfer/sampling.cuh

+102-43
Original file line numberDiff line numberDiff line change
@@ -943,9 +943,15 @@ struct RenormTempStorage {
943943
float max_val;
944944
float min_val;
945945
union {
946-
float value;
947-
int count;
948-
ValueCount<float> pair;
946+
struct {
947+
float values[2];
948+
};
949+
struct {
950+
int counts[2];
951+
};
952+
struct {
953+
ValueCount<float> pairs[2];
954+
};
949955
} block_aggregate;
950956
};
951957
};
@@ -964,7 +970,6 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
964970
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
965971
temp_storage.max_val = 0;
966972
vec_t<float, VEC_SIZE> probs_vec;
967-
float probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
968973

969974
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
970975
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(probs, row_idx, d,
@@ -981,8 +986,10 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
981986
// stopping condition
982987
// - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
983988
do {
984-
float threadlocal_sum = 0;
985-
double mid = (low + high) / 2;
989+
double pivot_0 = (high + 2 * low) / 3;
990+
double pivot_1 = (2 * high + low) / 3;
991+
992+
float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0;
986993
min_gt_low = high;
987994
max_le_high = low;
988995
#pragma unroll 2
@@ -991,19 +998,29 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
991998
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
992999
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
9931000
}
1001+
1002+
float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
9941003
#pragma unroll
9951004
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
996-
probs_greater_than_pivot[j] = (probs_vec[j] > mid) ? probs_vec[j] : 0;
1005+
probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0;
1006+
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
1007+
9971008
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
9981009
min_gt_low = min(min_gt_low, probs_vec[j]);
9991010
}
10001011
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
10011012
max_le_high = max(max_le_high, probs_vec[j]);
10021013
}
10031014
}
1004-
threadlocal_sum +=
1015+
1016+
aggregate_gt_pivot_0 +=
1017+
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
1018+
.Sum<VEC_SIZE>(probs_gt_pivot_0);
1019+
__syncthreads();
1020+
1021+
aggregate_gt_pivot_1 +=
10051022
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
1006-
.Sum<VEC_SIZE>(probs_greater_than_pivot);
1023+
.Sum<VEC_SIZE>(probs_gt_pivot_1);
10071024
__syncthreads();
10081025
}
10091026
min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
@@ -1013,19 +1030,26 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
10131030
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
10141031
.Reduce(max_le_high, cub::Max());
10151032
if (tx == 0) {
1016-
temp_storage.block_aggregate.value = threadlocal_sum;
1033+
temp_storage.block_aggregate.values[0] = aggregate_gt_pivot_0;
1034+
temp_storage.block_aggregate.values[1] = aggregate_gt_pivot_1;
10171035
temp_storage.min_val = min_gt_low;
10181036
temp_storage.max_val = max_le_high;
10191037
}
10201038
__syncthreads();
1021-
threadlocal_sum = temp_storage.block_aggregate.value;
1039+
aggregate_gt_pivot_0 = temp_storage.block_aggregate.values[0];
1040+
aggregate_gt_pivot_1 = temp_storage.block_aggregate.values[1];
10221041
min_gt_low = temp_storage.min_val;
10231042
max_le_high = temp_storage.max_val;
1024-
if (threadlocal_sum >= p) {
1025-
low = mid;
1026-
sum_low = threadlocal_sum;
1043+
1044+
if (aggregate_gt_pivot_1 >= p) {
1045+
low = pivot_1;
1046+
sum_low = aggregate_gt_pivot_1;
1047+
} else if (aggregate_gt_pivot_0 >= p) {
1048+
low = pivot_0;
1049+
high = min(pivot_1, max_le_high);
1050+
sum_low = aggregate_gt_pivot_0;
10271051
} else {
1028-
high = min(mid, max_le_high);
1052+
high = min(pivot_0, max_le_high);
10291053
}
10301054
} while (min_gt_low != max_le_high);
10311055

@@ -1079,9 +1103,10 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
10791103
// stopping condition: min_gt_low == max_le_high
10801104
// - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
10811105
do {
1082-
int threadlocal_count_sum = 0;
1083-
int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0
1084-
double mid = (low + high) / 2;
1106+
double pivot_0 = (high + 2 * low) / 3;
1107+
double pivot_1 = (2 * high + low) / 3;
1108+
1109+
int aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0;
10851110
min_gt_low = high;
10861111
max_le_high = low;
10871112
#pragma unroll 2
@@ -1090,20 +1115,30 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
10901115
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
10911116
logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
10921117
}
1118+
int probs_gt_pivot_0_count[VEC_SIZE], probs_gt_pivot_1_count[VEC_SIZE];
10931119
#pragma unroll
10941120
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
1095-
probs_greater_than_pivot_count[j] =
1096-
logits_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
1121+
probs_gt_pivot_0_count[j] =
1122+
logits_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
1123+
probs_gt_pivot_1_count[j] =
1124+
logits_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
1125+
10971126
if (logits_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
10981127
min_gt_low = min(min_gt_low, logits_vec[j]);
10991128
}
11001129
if (logits_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
11011130
max_le_high = max(max_le_high, logits_vec[j]);
11021131
}
11031132
}
1104-
threadlocal_count_sum +=
1133+
1134+
aggregate_gt_pivot_0 +=
1135+
BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce_int)
1136+
.Sum<VEC_SIZE>(probs_gt_pivot_0_count);
1137+
__syncthreads();
1138+
1139+
aggregate_gt_pivot_1 +=
11051140
BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce_int)
1106-
.Sum<VEC_SIZE>(probs_greater_than_pivot_count);
1141+
.Sum<VEC_SIZE>(probs_gt_pivot_1_count);
11071142
__syncthreads();
11081143
}
11091144
min_gt_low =
@@ -1114,18 +1149,24 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
11141149
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
11151150
.Reduce(max_le_high, cub::Max());
11161151
if (tx == 0) {
1117-
temp_storage.block_aggregate.count = threadlocal_count_sum;
1152+
temp_storage.block_aggregate.counts[0] = aggregate_gt_pivot_0;
1153+
temp_storage.block_aggregate.counts[1] = aggregate_gt_pivot_1;
11181154
temp_storage.min_val = min_gt_low;
11191155
temp_storage.max_val = max_le_high;
11201156
}
11211157
__syncthreads();
1122-
threadlocal_count_sum = temp_storage.block_aggregate.count;
1158+
aggregate_gt_pivot_0 = temp_storage.block_aggregate.counts[0];
1159+
aggregate_gt_pivot_1 = temp_storage.block_aggregate.counts[1];
11231160
min_gt_low = temp_storage.min_val;
11241161
max_le_high = temp_storage.max_val;
1125-
if (threadlocal_count_sum >= k) {
1126-
low = mid;
1162+
1163+
if (aggregate_gt_pivot_1 >= k) {
1164+
low = pivot_1;
1165+
} else if (aggregate_gt_pivot_0 >= k) {
1166+
low = pivot_0;
1167+
high = min(pivot_1, max_le_high);
11271168
} else {
1128-
high = min(mid, max_le_high);
1169+
high = min(pivot_0, max_le_high);
11291170
}
11301171
} while (min_gt_low != max_le_high);
11311172
pivot = low;
@@ -1164,7 +1205,6 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
11641205
auto& temp_storage =
11651206
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
11661207
temp_storage.max_val = 0;
1167-
float probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
11681208

11691209
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
11701210
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
@@ -1181,9 +1221,10 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
11811221
// stopping condition: min_gt_low == max_le_high
11821222
// - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
11831223
do {
1184-
ValueCount<float> threadlocal_sum{0, 0};
1185-
ValueCount<float> probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0
1186-
double mid = (low + high) / 2;
1224+
double pivot_0 = (high + 2 * low) / 3;
1225+
double pivot_1 = (2 * high + low) / 3;
1226+
1227+
ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
11871228
min_gt_low = high;
11881229
max_le_high = low;
11891230
#pragma unroll 2
@@ -1192,21 +1233,32 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
11921233
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
11931234
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
11941235
}
1236+
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
11951237
#pragma unroll
11961238
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
1197-
probs_greater_than_pivot_pair[j] = {
1198-
(probs_vec[j] > mid) ? probs_vec[j] : 0,
1199-
(probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
1239+
probs_gt_pivot_0_pair[j] = {
1240+
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
1241+
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
1242+
probs_gt_pivot_1_pair[j] = {
1243+
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
1244+
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
1245+
12001246
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
12011247
min_gt_low = min(min_gt_low, probs_vec[j]);
12021248
}
12031249
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
12041250
max_le_high = max(max_le_high, probs_vec[j]);
12051251
}
12061252
}
1207-
threadlocal_sum += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
1208-
temp_storage.block_prim.reduce_value_count)
1209-
.Sum<VEC_SIZE>(probs_greater_than_pivot_pair);
1253+
1254+
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
1255+
temp_storage.block_prim.reduce_value_count)
1256+
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
1257+
__syncthreads();
1258+
1259+
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
1260+
temp_storage.block_prim.reduce_value_count)
1261+
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
12101262
__syncthreads();
12111263
}
12121264
min_gt_low =
@@ -1217,19 +1269,26 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
12171269
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
12181270
.Reduce(max_le_high, cub::Max());
12191271
if (tx == 0) {
1220-
temp_storage.block_aggregate.pair = threadlocal_sum;
1272+
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
1273+
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
12211274
temp_storage.min_val = min_gt_low;
12221275
temp_storage.max_val = max_le_high;
12231276
}
12241277
__syncthreads();
1225-
threadlocal_sum = temp_storage.block_aggregate.pair;
1278+
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0];
1279+
aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1];
12261280
min_gt_low = temp_storage.min_val;
12271281
max_le_high = temp_storage.max_val;
1228-
if (threadlocal_sum.count >= k) {
1229-
low = mid;
1230-
sum_low = float(threadlocal_sum.value);
1282+
1283+
if (aggregate_gt_pivot_1.count >= k) {
1284+
low = pivot_1;
1285+
sum_low = float(aggregate_gt_pivot_1.value);
1286+
} else if (aggregate_gt_pivot_0.count >= k) {
1287+
low = pivot_0;
1288+
high = min(pivot_1, max_le_high);
1289+
sum_low = float(aggregate_gt_pivot_0.value);
12311290
} else {
1232-
high = min(mid, max_le_high);
1291+
high = min(pivot_0, max_le_high);
12331292
}
12341293
} while (min_gt_low != max_le_high);
12351294

0 commit comments

Comments
 (0)