Skip to content

perf: use max probability instead of 1 as upper bound in top-p/k sampling #925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 103 additions & 100 deletions include/flashinfer/sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cub/block/block_scan.cuh>
#include <cuda/std/limits>
#include <numeric>
#include <tuple>

#include "math.cuh"
#include "utils.cuh"
Expand Down Expand Up @@ -97,10 +98,10 @@ struct SamplingTempStorage {
struct {
int32_t sampled_id;
int32_t last_valid_id;
float max_val;
union {
float value;
ValueCount<float> pair;
float max_p;
} block_aggregate;
};
};
Expand Down Expand Up @@ -190,6 +191,75 @@ __device__ __forceinline__ void DeterministicInclusiveSum(
}
}

/*!
 * \brief Block-wide minimum and maximum of one row of \p in_data.
 *
 * The row is walked in tiles of BLOCK_THREADS * VEC_SIZE elements; every thread loads VEC_SIZE
 * consecutive floats per tile and the tile is reduced across the block. The final values are
 * published through \p temp_storage so that every thread returns the same pair.
 *
 * \tparam VEC_SIZE Number of floats each thread loads per tile.
 * \tparam BLOCK_THREADS Number of threads taking part in the block reduction.
 * \tparam REDUCE_ALGORITHM Algorithm forwarded to BlockReduce.
 * \tparam TempStorage Scratch type; must expose `block_prim.reduce`, `max_val` and `min_val`.
 * \param in_data Base pointer of the input; the row starts at `in_data + row_idx * d`.
 * \param row_idx Index of the row to scan.
 * \param d Number of elements per row.
 * \param temp_storage Scratch used for the reductions and to broadcast the results.
 * \return `(min_val, max_val)` -- note the order: minimum first.
 *
 * \note Must be called by all threads of the block: it contains `__syncthreads()` barriers.
 * \note Threads whose slice starts at or beyond \p d contribute zeros to both reductions, so
 *   whenever the last tile overhangs \p d the returned minimum is <= 0 and the maximum is >= 0.
 *   NOTE(review): that keeps the pair usable as conservative search bounds, but it is not the
 *   exact row min/max in that case -- confirm callers only rely on the former.
 */
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
          typename TempStorage>
__device__ __forceinline__ std::tuple<float, float> GetMinMaxValue(float* in_data, uint32_t row_idx,
                                                                   uint32_t d,
                                                                   TempStorage& temp_storage) {
  const uint32_t tx = threadIdx.x;
  // Widen before multiplying: `row_idx * d` evaluated in 32-bit arithmetic wraps around as soon
  // as the product reaches 2^32 elements, which would silently address the wrong row.
  float* const row_ptr = in_data + static_cast<size_t>(row_idx) * d;
  vec_t<float, VEC_SIZE> in_data_vec;
  float max_val = -cuda::std::numeric_limits<float>::infinity(),
        min_val = cuda::std::numeric_limits<float>::infinity();
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    // First element of this thread's slice within the row for tile i.
    const uint32_t elem_offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
    in_data_vec.fill(0);
    if (elem_offset < d) {
      in_data_vec.cast_load(row_ptr + elem_offset);
    }
    float in_data_[VEC_SIZE];
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      in_data_[j] = in_data_vec[j];
    }
    // `template` is required here: BlockReduce<...> is a dependent type, so without the
    // disambiguator the `<` after Reduce is not guaranteed to be parsed as a template argument
    // list.
    max_val = max(
        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                     .template Reduce<VEC_SIZE>(in_data_, cub::Max()));
    // Both reductions go through the same temp_storage.block_prim.reduce scratch, so a barrier
    // separates each use from the next one.
    __syncthreads();
    min_val = min(
        min_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                     .template Reduce<VEC_SIZE>(in_data_, cub::Min()));
    __syncthreads();
  }
  // Only thread 0's accumulators are published (presumably because the block reduction result
  // is only defined there -- see the CUB BlockReduce documentation); everyone reads them back
  // after the barrier.
  if (tx == 0) {
    temp_storage.max_val = max_val;
    temp_storage.min_val = min_val;
  }
  __syncthreads();
  max_val = temp_storage.max_val;
  min_val = temp_storage.min_val;

  return std::make_tuple(min_val, max_val);
}

/*!
 * \brief Block-wide maximum of one row of \p in_data, broadcast to every thread.
 *
 * The row is walked in tiles of BLOCK_THREADS * VEC_SIZE elements; every thread loads VEC_SIZE
 * consecutive floats per tile and the tile is reduced across the block. The result is stored
 * in `temp_storage.max_val` and returned from there by all threads.
 *
 * \tparam VEC_SIZE Number of floats each thread loads per tile.
 * \tparam BLOCK_THREADS Number of threads taking part in the block reduction.
 * \tparam REDUCE_ALGORITHM Algorithm forwarded to BlockReduce.
 * \tparam TempStorage Scratch type; must expose `block_prim.reduce` and `max_val`.
 * \param in_data Base pointer of the input; the row starts at `in_data + row_idx * d`.
 * \param row_idx Index of the row to scan.
 * \param d Number of elements per row.
 * \param temp_storage Scratch used for the reduction and to broadcast the result.
 * \return The value of `temp_storage.max_val` after the reduction.
 *
 * \note Must be called by all threads of the block: it contains `__syncthreads()` barriers.
 * \note The accumulator starts at 0 and out-of-range lanes are padded with 0, so the result is
 *   never below 0. That is exact for non-negative inputs such as probabilities (how the visible
 *   callers use it); for rows that are entirely negative it returns 0 instead of the true max.
 */
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
          typename TempStorage>
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
                                             TempStorage& temp_storage) {
  const uint32_t tx = threadIdx.x;
  // Widen before multiplying: `row_idx * d` evaluated in 32-bit arithmetic wraps around as soon
  // as the product reaches 2^32 elements, which would silently address the wrong row.
  float* const row_ptr = in_data + static_cast<size_t>(row_idx) * d;
  vec_t<float, VEC_SIZE> in_data_vec;

  float max_val = 0;
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    // First element of this thread's slice within the row for tile i.
    const uint32_t elem_offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
    in_data_vec.fill(0);
    if (elem_offset < d) {
      in_data_vec.cast_load(row_ptr + elem_offset);
    }
    float in_data_[VEC_SIZE];
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      in_data_[j] = in_data_vec[j];
    }
    // `template` is required here: BlockReduce<...> is a dependent type, so without the
    // disambiguator the `<` after Reduce is not guaranteed to be parsed as a template argument
    // list.
    max_val = max(
        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                     .template Reduce<VEC_SIZE>(in_data_, cub::Max()));
    // The reduction scratch is reused by the next tile, so wait until every thread is done
    // with it.
    __syncthreads();
  }
  // Only thread 0's accumulator is published (presumably because the block reduction result is
  // only defined there -- see the CUB BlockReduce documentation); everyone reads it back after
  // the barrier.
  if (tx == 0) {
    temp_storage.max_val = max_val;
  }
  __syncthreads();
  return temp_storage.max_val;
}

template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename Predicate>
__device__ __forceinline__ void DeviceSamplingFromProb(
Expand Down Expand Up @@ -335,10 +405,14 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);

float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);

vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
float q = 1;
double low = 0, high = 1;
double low = 0, high = max_val;
int sampled_id;
int round = 0;
do {
Expand Down Expand Up @@ -448,10 +522,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);

float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);

vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
float q = 1;
double low = 0, high = 1;
double low = 0, high = max_val;
int sampled_id;
do {
temp_storage.sampled_id = d;
Expand Down Expand Up @@ -552,29 +630,12 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);

vec_t<float, VEC_SIZE> probs_vec;

float max_p = 0;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
float probs_[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_[j] = probs_vec[j];
}
max_p = max(max_p, BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(probs_, cub::Max()));
__syncthreads();
}
if (tx == 0) {
temp_storage.block_aggregate.max_p = max_p;
}
__syncthreads();
float pivot = temp_storage.block_aggregate.max_p * p;
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);
float pivot = max_val * p;

vec_t<float, VEC_SIZE> probs_vec;
float aggregate_gt_pivot = 0;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
Expand Down Expand Up @@ -648,10 +709,14 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);

float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);

vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
float q = 1;
double low = 0, high = 1;
double low = 0, high = max_val;
int sampled_id;
do {
temp_storage.sampled_id = d;
Expand Down Expand Up @@ -904,29 +969,11 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
vec_t<float, VEC_SIZE> probs_vec;
float probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0

float threadlocal_max_val = 0;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_greater_than_pivot[j] = probs_vec[j];
}
threadlocal_max_val =
max(threadlocal_max_val,
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
__syncthreads();
}
if (tx == 0) {
temp_storage.max_val = threadlocal_max_val;
}
__syncthreads();
threadlocal_max_val = temp_storage.max_val;
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(probs, row_idx, d,
temp_storage);

double low = 0, high = threadlocal_max_val;
double low = 0, high = max_val;
float min_gt_low, max_le_high;
float sum_low = 1;
// f(x) = sum(probs[probs > x]), f(x) is non-increasing
Expand Down Expand Up @@ -1019,37 +1066,11 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
float logits_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0

float threadlocal_max_val = -cuda::std::numeric_limits<float>::infinity(),
threadlocal_min_val = cuda::std::numeric_limits<float>::infinity();
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
logits_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
logits_greater_than_pivot[j] = logits_vec[j];
}
threadlocal_max_val =
max(threadlocal_max_val,
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Max()));
__syncthreads();
threadlocal_min_val =
min(threadlocal_min_val,
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Min()));
__syncthreads();
}
if (tx == 0) {
temp_storage.max_val = threadlocal_max_val;
temp_storage.min_val = threadlocal_min_val;
}
__syncthreads();
threadlocal_max_val = temp_storage.max_val;
threadlocal_min_val = temp_storage.min_val;
auto [min_val, max_val] = GetMinMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
logits, row_idx, d, temp_storage);

double low = threadlocal_min_val - 1, high = threadlocal_max_val;
double low = min_val - 1, high = max_val;
float min_gt_low, max_le_high;
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
Expand Down Expand Up @@ -1144,29 +1165,11 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
temp_storage.max_val = 0;
float probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0

float threadlocal_max_val = 0;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_greater_than_pivot[j] = probs_vec[j];
}
threadlocal_max_val =
max(threadlocal_max_val,
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
__syncthreads();
}
if (tx == 0) {
temp_storage.max_val = threadlocal_max_val;
}
__syncthreads();
threadlocal_max_val = temp_storage.max_val;
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);

double low = 0, high = threadlocal_max_val;
double low = 0, high = max_val;
float min_gt_low, max_le_high;
float sum_low = 1;
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
Expand Down