Skip to content

Commit e3853dd

Browse files
authored
perf: use max probability instead of 1 as upper bound in top-p/k sampling (#925)
Reduce the maximum number of binary-search iterations by initializing the search upper bound to the row's maximum probability instead of the constant 1.
1 parent fb578e7 commit e3853dd

File tree

1 file changed

+103
-100
lines changed

1 file changed

+103
-100
lines changed

include/flashinfer/sampling.cuh

+103-100
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <cub/block/block_scan.cuh>
2626
#include <cuda/std/limits>
2727
#include <numeric>
28+
#include <tuple>
2829

2930
#include "math.cuh"
3031
#include "utils.cuh"
@@ -97,10 +98,10 @@ struct SamplingTempStorage {
9798
struct {
9899
int32_t sampled_id;
99100
int32_t last_valid_id;
101+
float max_val;
100102
union {
101103
float value;
102104
ValueCount<float> pair;
103-
float max_p;
104105
} block_aggregate;
105106
};
106107
};
@@ -190,6 +191,75 @@ __device__ __forceinline__ void DeterministicInclusiveSum(
190191
}
191192
}
192193

194+
/*!
 * \brief Cooperative block-wide reduction returning the (min, max) of one row of in_data.
 *
 * All BLOCK_THREADS threads of the block must call this function together: it uses
 * CUB block reductions and __syncthreads() barriers internally. Each iteration every
 * thread loads a contiguous vector of VEC_SIZE floats; the per-iteration block
 * reductions are folded into running min/max values, which thread 0 then publishes
 * through temp_storage so that every thread returns the same pair.
 *
 * \tparam VEC_SIZE vectorized load width (floats per thread per iteration)
 * \tparam BLOCK_THREADS number of threads in the block
 * \tparam REDUCE_ALGORITHM CUB block-reduce algorithm
 * \tparam TempStorage shared-memory storage type exposing block_prim.reduce,
 *         max_val and min_val fields
 * \param in_data device pointer to a row-major [num_rows, d] float matrix
 * \param row_idx row to reduce over
 * \param d number of elements in the row
 * \param temp_storage shared-memory workspace, also used to broadcast the result
 * \return std::tuple{min_val, max_val} of the row, identical on every thread
 *
 * \note NOTE(review): when a lane's vector starts in range but d is not a multiple
 *       of VEC_SIZE, the tail load reads past the row end — presumably buffers are
 *       padded/aligned to VEC_SIZE; confirm against callers.
 */
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
          typename TempStorage>
__device__ __forceinline__ std::tuple<float, float> GetMinMaxValue(float* in_data, uint32_t row_idx,
                                                                   uint32_t d,
                                                                   TempStorage& temp_storage) {
  const uint32_t tx = threadIdx.x;
  vec_t<float, VEC_SIZE> in_data_vec;
  float max_val = -cuda::std::numeric_limits<float>::infinity(),
        min_val = cuda::std::numeric_limits<float>::infinity();
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    const bool in_range = (i * BLOCK_THREADS + tx) * VEC_SIZE < d;
    if (in_range) {
      in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
    }
    // Out-of-range lanes contribute +/-infinity sentinels instead of zeros:
    // a zero fill would clamp the reduced min toward <= 0 (and the max toward
    // >= 0) whenever d is not a multiple of BLOCK_THREADS * VEC_SIZE, loosening
    // the binary-search bounds callers derive from this pair.
    float in_data_max_[VEC_SIZE];
    float in_data_min_[VEC_SIZE];
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      in_data_max_[j] = in_range ? in_data_vec[j] : -cuda::std::numeric_limits<float>::infinity();
      in_data_min_[j] = in_range ? in_data_vec[j] : cuda::std::numeric_limits<float>::infinity();
    }
    max_val = max(
        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                     .Reduce<VEC_SIZE>(in_data_max_, cub::Max()));
    // Barrier: the reduce temp storage is reused by the next reduction.
    __syncthreads();
    min_val = min(
        min_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                     .Reduce<VEC_SIZE>(in_data_min_, cub::Min()));
    __syncthreads();
  }
  // Only thread 0 holds the valid block aggregate; broadcast it via shared memory.
  if (tx == 0) {
    temp_storage.max_val = max_val;
    temp_storage.min_val = min_val;
  }
  __syncthreads();
  max_val = temp_storage.max_val;
  min_val = temp_storage.min_val;

  return std::make_tuple(min_val, max_val);
}
232+
233+
/*!
 * \brief Cooperative block-wide reduction returning the maximum of one row of in_data.
 *
 * Must be called by all BLOCK_THREADS threads of the block together (it uses CUB
 * block reductions and __syncthreads() barriers). Each iteration every thread
 * loads VEC_SIZE contiguous floats and folds a block reduction into a running
 * maximum; thread 0 publishes the final value through temp_storage so that every
 * thread returns the same result.
 *
 * \tparam VEC_SIZE vectorized load width (floats per thread per iteration)
 * \tparam BLOCK_THREADS number of threads in the block
 * \tparam REDUCE_ALGORITHM CUB block-reduce algorithm
 * \tparam TempStorage shared-memory storage type exposing block_prim.reduce and max_val
 * \param in_data device pointer to a row-major [num_rows, d] float matrix
 * \param row_idx row to reduce over
 * \param d number of elements in the row
 * \param temp_storage shared-memory workspace, also used to broadcast the result
 * \return the row maximum, identical on every thread
 *
 * \note The running maximum is seeded with 0 and out-of-range lanes are
 *       zero-filled, so the result is max(row_max, 0) — presumably fine because
 *       callers pass probability rows (non-negative); confirm for new call sites.
 */
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
          typename TempStorage>
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
                                             TempStorage& temp_storage) {
  const uint32_t thread_id = threadIdx.x;
  const uint32_t num_iters = ceil_div(d, BLOCK_THREADS * VEC_SIZE);
  vec_t<float, VEC_SIZE> chunk;

  float running_max = 0;
  for (uint32_t iter = 0; iter < num_iters; ++iter) {
    const uint32_t elem_offset = (iter * BLOCK_THREADS + thread_id) * VEC_SIZE;
    // Lanes past the end of the row contribute zeros (harmless for a max over
    // non-negative probabilities).
    chunk.fill(0);
    if (elem_offset < d) {
      chunk.cast_load(in_data + row_idx * d + elem_offset);
    }
    float local_vals[VEC_SIZE];
#pragma unroll
    for (uint32_t k = 0; k < VEC_SIZE; ++k) {
      local_vals[k] = chunk[k];
    }
    const float iter_max =
        BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
            .Reduce<VEC_SIZE>(local_vals, cub::Max());
    running_max = max(running_max, iter_max);
    // Barrier: the reduce temp storage is reused on the next iteration.
    __syncthreads();
  }
  // Only thread 0 holds the valid block aggregate; broadcast it via shared memory.
  if (thread_id == 0) {
    temp_storage.max_val = running_max;
  }
  __syncthreads();
  return temp_storage.max_val;
}
262+
193263
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
194264
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename Predicate>
195265
__device__ __forceinline__ void DeviceSamplingFromProb(
@@ -335,10 +405,14 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
335405
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
336406
smem_sampling);
337407

408+
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
409+
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
410+
probs, row_idx, d, temp_storage);
411+
338412
vec_t<float, VEC_SIZE> probs_vec;
339413
float aggregate;
340414
float q = 1;
341-
double low = 0, high = 1;
415+
double low = 0, high = max_val;
342416
int sampled_id;
343417
int round = 0;
344418
do {
@@ -448,10 +522,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
448522
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
449523
smem_sampling);
450524

525+
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
526+
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
527+
probs, row_idx, d, temp_storage);
528+
451529
vec_t<float, VEC_SIZE> probs_vec;
452530
float aggregate;
453531
float q = 1;
454-
double low = 0, high = 1;
532+
double low = 0, high = max_val;
455533
int sampled_id;
456534
do {
457535
temp_storage.sampled_id = d;
@@ -552,29 +630,12 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
552630
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
553631
smem_sampling);
554632

555-
vec_t<float, VEC_SIZE> probs_vec;
556-
557-
float max_p = 0;
558-
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
559-
probs_vec.fill(0);
560-
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
561-
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
562-
}
563-
float probs_[VEC_SIZE];
564-
#pragma unroll
565-
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
566-
probs_[j] = probs_vec[j];
567-
}
568-
max_p = max(max_p, BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
569-
.Reduce<VEC_SIZE>(probs_, cub::Max()));
570-
__syncthreads();
571-
}
572-
if (tx == 0) {
573-
temp_storage.block_aggregate.max_p = max_p;
574-
}
575-
__syncthreads();
576-
float pivot = temp_storage.block_aggregate.max_p * p;
633+
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
634+
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
635+
probs, row_idx, d, temp_storage);
636+
float pivot = max_val * p;
577637

638+
vec_t<float, VEC_SIZE> probs_vec;
578639
float aggregate_gt_pivot = 0;
579640
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
580641
probs_vec.fill(0);
@@ -648,10 +709,14 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
648709
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
649710
smem_sampling);
650711

712+
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
713+
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
714+
probs, row_idx, d, temp_storage);
715+
651716
vec_t<float, VEC_SIZE> probs_vec;
652717
float aggregate;
653718
float q = 1;
654-
double low = 0, high = 1;
719+
double low = 0, high = max_val;
655720
int sampled_id;
656721
do {
657722
temp_storage.sampled_id = d;
@@ -904,29 +969,11 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
904969
vec_t<float, VEC_SIZE> probs_vec;
905970
float probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
906971

907-
float threadlocal_max_val = 0;
908-
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
909-
probs_vec.fill(0);
910-
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
911-
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
912-
}
913-
#pragma unroll
914-
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
915-
probs_greater_than_pivot[j] = probs_vec[j];
916-
}
917-
threadlocal_max_val =
918-
max(threadlocal_max_val,
919-
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
920-
.Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
921-
__syncthreads();
922-
}
923-
if (tx == 0) {
924-
temp_storage.max_val = threadlocal_max_val;
925-
}
926-
__syncthreads();
927-
threadlocal_max_val = temp_storage.max_val;
972+
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
973+
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(probs, row_idx, d,
974+
temp_storage);
928975

929-
double low = 0, high = threadlocal_max_val;
976+
double low = 0, high = max_val;
930977
float min_gt_low, max_le_high;
931978
float sum_low = 1;
932979
// f(x) = sum(probs[probs > x]), f(x) is non-increasing
@@ -1019,37 +1066,11 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
10191066
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
10201067
float logits_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
10211068

1022-
float threadlocal_max_val = -cuda::std::numeric_limits<float>::infinity(),
1023-
threadlocal_min_val = cuda::std::numeric_limits<float>::infinity();
1024-
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
1025-
logits_vec.fill(0);
1026-
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
1027-
logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
1028-
}
1029-
#pragma unroll
1030-
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
1031-
logits_greater_than_pivot[j] = logits_vec[j];
1032-
}
1033-
threadlocal_max_val =
1034-
max(threadlocal_max_val,
1035-
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
1036-
.Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Max()));
1037-
__syncthreads();
1038-
threadlocal_min_val =
1039-
min(threadlocal_min_val,
1040-
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
1041-
.Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Min()));
1042-
__syncthreads();
1043-
}
1044-
if (tx == 0) {
1045-
temp_storage.max_val = threadlocal_max_val;
1046-
temp_storage.min_val = threadlocal_min_val;
1047-
}
1048-
__syncthreads();
1049-
threadlocal_max_val = temp_storage.max_val;
1050-
threadlocal_min_val = temp_storage.min_val;
1069+
auto [min_val, max_val] = GetMinMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
1070+
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
1071+
logits, row_idx, d, temp_storage);
10511072

1052-
double low = threadlocal_min_val - 1, high = threadlocal_max_val;
1073+
double low = min_val - 1, high = max_val;
10531074
float min_gt_low, max_le_high;
10541075
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
10551076
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
@@ -1144,29 +1165,11 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
11441165
temp_storage.max_val = 0;
11451166
float probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
11461167

1147-
float threadlocal_max_val = 0;
1148-
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
1149-
probs_vec.fill(0);
1150-
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
1151-
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
1152-
}
1153-
#pragma unroll
1154-
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
1155-
probs_greater_than_pivot[j] = probs_vec[j];
1156-
}
1157-
threadlocal_max_val =
1158-
max(threadlocal_max_val,
1159-
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
1160-
.Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
1161-
__syncthreads();
1162-
}
1163-
if (tx == 0) {
1164-
temp_storage.max_val = threadlocal_max_val;
1165-
}
1166-
__syncthreads();
1167-
threadlocal_max_val = temp_storage.max_val;
1168+
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
1169+
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
1170+
probs, row_idx, d, temp_storage);
11681171

1169-
double low = 0, high = threadlocal_max_val;
1172+
double low = 0, high = max_val;
11701173
float min_gt_low, max_le_high;
11711174
float sum_low = 1;
11721175
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing

0 commit comments

Comments
 (0)