Commit 0dce178

misc: improve error handling of sampling kernels (#456)
Add an option to check whether there are NaN inputs. This PR also removes all `eps` arguments from the renorm kernels: previously we pre-set an eps constant to decide when to stop the binary search, but this becomes inaccurate as the vocabulary size grows (e.g. >= 1e6 in llama3, where our eps might be set to 1e-5). This PR instead implements a loop variant that does not rely on any external eps, which helps address issues such as vllm-project/vllm#7137 (comment).
1 parent 0d61871 commit 0dce178
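
To illustrate the new stopping rule, here is a minimal host-side C++ sketch of the eps-free binary search (serial, not the CUDA kernel; the function name TopPPivot is hypothetical, and a non-empty probs vector is assumed):

#include <algorithm>
#include <vector>

// Binary-search the top-p pivot without an eps threshold. Each round also
// tracks min_gt_low (smallest prob > low) and max_le_high (largest prob
// <= high) and clamps high to max_le_high, so the bounds always converge
// onto actual probability values; the loop exits once no distinct value
// separates them, independent of vocabulary size.
float TopPPivot(const std::vector<float>& probs, float top_p) {
  float low = 0.f, high = *std::max_element(probs.begin(), probs.end());
  float min_gt_low, max_le_high;
  do {
    float mid = (low + high) / 2;
    min_gt_low = high;   // min{x in probs | x > low}
    max_le_high = low;   // max{x in probs | x <= high}
    float sum = 0.f;     // f(mid) = sum(probs[probs > mid]), non-increasing in mid
    for (float x : probs) {
      if (x > mid) sum += x;
      if (x > low) min_gt_low = std::min(min_gt_low, x);
      if (x <= high) max_le_high = std::max(max_le_high, x);
    }
    if (sum >= top_p) {
      low = mid;                          // invariant: f(low) >= top_p
    } else {
      high = std::min(mid, max_le_high);  // invariant: f(high) < top_p
    }
  } while (min_gt_low != max_le_high);
  return low;  // keep tokens with prob > low, then renormalize their mass
}

The kernels below use the same do/while structure, computing the block-wide min/max with cub BlockReduce instead of the serial loop.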

File tree

5 files changed: +172 -63 lines changed


include/flashinfer/sampling.cuh

+104-30
@@ -16,8 +16,6 @@
 #ifndef FLASHINFER_SAMPLING_CUH_
 #define FLASHINFER_SAMPLING_CUH_
 
-#include <driver_types.h>
-
 #include <cub/block/block_adjacent_difference.cuh>
 #include <cub/block/block_reduce.cuh>
 #include <cub/block/block_scan.cuh>
@@ -347,13 +345,13 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
   }
   __syncthreads();
   if (tx == 0) {
+    output[bx] = sampled_id;
     if (temp_storage.data.block_aggregate.pair.count >= k) {
       // failed to sample within MAX_TOP_P_ROUNDS
       if (success != nullptr) {
         success[bx] = false;
       }
     } else {
-      output[bx] = sampled_id;
       if (success != nullptr) {
         success[bx] = true;
       }
@@ -433,13 +431,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
   }
   __syncthreads();
   if (tx == 0) {
+    output[bx] = sampled_id;
     if (float(q) >= top_p) {
       // failed to sample within MAX_TOP_P_ROUNDS
       if (success != nullptr) {
         success[bx] = false;
       }
     } else {
-      output[bx] = sampled_id;
       if (success != nullptr) {
         success[bx] = true;
       }
@@ -539,13 +537,13 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
   }
   __syncthreads();
   if (tx == 0) {
+    output[bx] = sampled_id;
     if (pivot < scaled_p) {
       // failed to sample within MAX_ROUNDS
       if (success != nullptr) {
         success[bx] = false;
       }
     } else {
-      output[bx] = sampled_id;
       if (success != nullptr) {
         success[bx] = true;
       }
@@ -627,13 +625,13 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp
   }
   __syncthreads();
   if (tx == 0) {
+    output[bx] = sampled_id;
     if (temp_storage.data.block_aggregate.pair.count >= k || float(q) >= p) {
       // failed to sample within MAX_TOP_P_ROUNDS
       if (success != nullptr) {
         success[bx] = false;
       }
     } else {
-      output[bx] = sampled_id;
       if (success != nullptr) {
         success[bx] = true;
       }
@@ -808,7 +806,7 @@ struct RenormTempStorage {
 template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
           typename DType>
 __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* top_p_arr,
-                                     float top_p_val, float eps, uint32_t d) {
+                                     float top_p_val, uint32_t d) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = bx;
   float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
@@ -844,12 +842,20 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
   threadlocal_max_val = temp_storage.data.max_val;
 
   float low = 0, high = threadlocal_max_val;
+  DType min_gt_low, max_le_high;
   DType sum_low(1);
-  // f(x) = probs[probs > x], f(x) is non-increasing
-  // loop invariant: f(low) >= p, f(high) < p
-  while (high - low > eps) {
+  // f(x) = sum(probs[probs > x]), f(x) is non-increasing
+  // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
+  // loop invariant:
+  // - f(low) >= p, f(high) < p
+  // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+  // stopping condition
+  // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
+  do {
     DType threadlocal_sum(0);
     float mid = (low + high) / 2;
+    min_gt_low = high;
+    max_le_high = low;
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(DType(0));
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -858,26 +864,42 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
 #pragma unroll
       for (uint32_t j = 0; j < VEC_SIZE; ++j) {
         probs_greater_than_pivot[j] = (probs_vec[j] > mid) ? probs_vec[j] : DType(0);
+        if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+          min_gt_low = min(min_gt_low, probs_vec[j]);
+        }
+        if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+          max_le_high = max(max_le_high, probs_vec[j]);
+        }
       }
       threadlocal_sum +=
           BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
              .Sum<VEC_SIZE>(probs_greater_than_pivot);
       __syncthreads();
     }
+    min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                     .Reduce(min_gt_low, cub::Min());
+    __syncthreads();
+    max_le_high =
+        BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+            .Reduce(max_le_high, cub::Max());
    if (tx == 0) {
       temp_storage.data.block_aggregate.value = threadlocal_sum;
+      temp_storage.data.min_val = min_gt_low;
+      temp_storage.data.max_val = max_le_high;
     }
     __syncthreads();
     threadlocal_sum = temp_storage.data.block_aggregate.value;
+    min_gt_low = temp_storage.data.min_val;
+    max_le_high = temp_storage.data.max_val;
     if (threadlocal_sum >= p) {
       low = mid;
       sum_low = float(threadlocal_sum);
     } else {
-      high = mid;
+      high = min(mid, max_le_high);
     }
-  }
+  } while (min_gt_low != max_le_high);
 
-  DType normalizer = math::ptx_rcp(max(sum_low, eps));
+  DType normalizer = math::ptx_rcp(max(sum_low, 1e-8));
 
   // normalize
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -898,7 +920,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
 template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
           typename DType, typename IdType>
 __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType* top_k_arr,
-                                     uint32_t top_k_val, float eps, uint32_t d) {
+                                     uint32_t top_k_val, uint32_t d) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = bx;
   uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -941,12 +963,20 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
     threadlocal_min_val = temp_storage.data.min_val;
 
     float low = threadlocal_min_val - 1, high = threadlocal_max_val;
+    DType min_gt_low, max_le_high;
     // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
-    // loop invariant: f(low) >= k, f(high) < k
-    while (high - low > eps) {
+    // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
+    // loop invariant:
+    // - f(low) >= k, f(high) < k
+    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+    // stopping condition: min_gt_low == max_le_high
+    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
+    do {
       int threadlocal_count_sum = 0;
       int probs_greater_than_pivot_count[VEC_SIZE];  // pivot initialized to 0
       float mid = (low + high) / 2;
+      min_gt_low = high;
+      max_le_high = low;
       for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
         logits_vec.fill(DType(0));
         if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -956,23 +986,41 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
         for (uint32_t j = 0; j < VEC_SIZE; ++j) {
           probs_greater_than_pivot_count[j] =
               logits_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
+          if (logits_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            min_gt_low = min(min_gt_low, logits_vec[j]);
+          }
+          if (logits_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            max_le_high = max(max_le_high, logits_vec[j]);
+          }
         }
         threadlocal_count_sum +=
             BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce_int)
                 .Sum<VEC_SIZE>(probs_greater_than_pivot_count);
         __syncthreads();
       }
+      min_gt_low =
+          BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Reduce(min_gt_low, cub::Min());
+      __syncthreads();
+      max_le_high =
+          BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Reduce(max_le_high, cub::Max());
+      __syncthreads();
       if (tx == 0) {
         temp_storage.data.block_aggregate.count = threadlocal_count_sum;
+        temp_storage.data.min_val = min_gt_low;
+        temp_storage.data.max_val = max_le_high;
       }
       __syncthreads();
       threadlocal_count_sum = temp_storage.data.block_aggregate.count;
+      min_gt_low = temp_storage.data.min_val;
+      max_le_high = temp_storage.data.max_val;
       if (threadlocal_count_sum >= k) {
         low = mid;
       } else {
-        high = mid;
+        high = min(mid, max_le_high);
       }
-    }
+    } while (min_gt_low != max_le_high);
     pivot = low;
   }
 
@@ -996,7 +1044,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
 template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
           typename DType, typename IdType>
 __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr,
-                                     uint32_t top_k_val, float eps, uint32_t d) {
+                                     uint32_t top_k_val, uint32_t d) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = bx;
   uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -1033,13 +1081,21 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
     threadlocal_max_val = temp_storage.data.max_val;
 
     float low = 0, high = threadlocal_max_val;
+    DType min_gt_low, max_le_high;
     DType sum_low(1);
     // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
-    // loop invariant: f(low) >= k, f(high) < k
-    while (high - low > eps) {
+    // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
+    // loop invariant:
+    // - f(low) >= k, f(high) < k
+    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+    // stopping condition: min_gt_low == max_le_high
+    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
+    do {
       Pair<DType> threadlocal_sum{DType(0), 0};
       Pair<DType> probs_greater_than_pivot_pair[VEC_SIZE];  // pivot initialized to 0
       float mid = (low + high) / 2;
+      min_gt_low = high;
+      max_le_high = low;
       for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
         probs_vec.fill(DType(0));
         if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1050,26 +1106,44 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
           probs_greater_than_pivot_pair[j] = {
               (probs_vec[j] > mid) ? probs_vec[j] : DType(0),
               (probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+          if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            min_gt_low = min(min_gt_low, probs_vec[j]);
+          }
+          if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            max_le_high = max(max_le_high, probs_vec[j]);
+          }
         }
         threadlocal_sum += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                temp_storage.block_prim.reduce_pair)
                                .Sum<VEC_SIZE>(probs_greater_than_pivot_pair);
         __syncthreads();
       }
+      min_gt_low =
+          BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Reduce(min_gt_low, cub::Min());
+      __syncthreads();
+      max_le_high =
+          BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Reduce(max_le_high, cub::Max());
+      __syncthreads();
       if (tx == 0) {
         temp_storage.data.block_aggregate.pair = threadlocal_sum;
+        temp_storage.data.min_val = min_gt_low;
+        temp_storage.data.max_val = max_le_high;
       }
       __syncthreads();
       threadlocal_sum = temp_storage.data.block_aggregate.pair;
+      min_gt_low = temp_storage.data.min_val;
+      max_le_high = temp_storage.data.max_val;
       if (threadlocal_sum.count >= k) {
         low = mid;
         sum_low = float(threadlocal_sum.value);
       } else {
-        high = mid;
+        high = min(mid, max_le_high);
       }
-    }
+    } while (min_gt_low != max_le_high);
 
-    normalizer = math::ptx_rcp(max(sum_low, eps));
+    normalizer = math::ptx_rcp(max(sum_low, 1e-8));
     pivot = low;
   }
 
@@ -1090,7 +1164,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
 }
 
 template <typename DType>
-cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, float eps,
+cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
                            uint32_t batch_size, float top_p_val, uint32_t d,
                            cudaStream_t stream = 0) {
   const uint32_t BLOCK_THREADS = 1024;
@@ -1099,7 +1173,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
   const uint32_t smem_size = sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &eps, &d};
+  void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d};
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = TopPRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType>;
     FLASHINFER_CUDA_CALL(
@@ -1110,7 +1184,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
 }
 
 template <typename DType, typename IdType>
-cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, float eps,
+cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
                            uint32_t batch_size, uint32_t top_k_val, uint32_t d,
                            cudaStream_t stream = 0) {
   const uint32_t BLOCK_THREADS = 1024;
@@ -1119,7 +1193,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr
   const uint32_t smem_size = sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &eps, &d};
+  void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d};
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
     FLASHINFER_CUDA_CALL(
@@ -1130,7 +1204,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr
 }
 
 template <typename DType, typename IdType>
-cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, float eps,
+cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr,
                            uint32_t batch_size, uint32_t top_k_val, uint32_t d,
                            cudaStream_t stream = 0) {
   const uint32_t BLOCK_THREADS = 1024;
@@ -1139,7 +1213,7 @@ cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_ar
   const uint32_t smem_size = sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &eps, &d};
+  void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d};
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = TopKMaskLogitsKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
     FLASHINFER_CUDA_CALL(
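
A hypothetical call site for the updated host API (note the eps argument is gone). The buffer names and the flashinfer::sampling namespace qualification below are assumptions for illustration, not part of this diff:

// d_probs / d_renormed: device buffers of shape [batch_size, vocab_size],
// assumed to be already allocated and filled with probabilities.
cudaError_t status = flashinfer::sampling::TopPRenormProb<float>(
    d_probs, d_renormed, /*top_p_arr=*/nullptr,  // nullptr: use scalar top_p_val for all rows
    batch_size, /*top_p_val=*/0.9f, /*d=*/vocab_size, stream);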

python/csrc/flashinfer_ops.h

+3-3
@@ -59,13 +59,13 @@ std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
     std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic);
 
 torch::Tensor top_p_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_p_arr,
-                                double top_p_val, double eps);
+                                double top_p_val);
 
 torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_k_arr,
-                                unsigned int top_k_val, double eps);
+                                unsigned int top_k_val);
 
 torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
-                                unsigned int top_k_val, double eps);
+                                unsigned int top_k_val);
 
 std::vector<torch::Tensor> chain_speculative_sampling(
     torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
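
The Python-facing bindings drop eps as well. A minimal sketch of how such declarations are typically registered, assuming a pybind11 torch-extension module as is conventional for python/csrc (the actual registration code is not shown in this diff):

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // All three ops now take no eps parameter, matching the headers above.
  m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probs over the top-p set");
  m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probs over the top-k set");
  m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits outside the top-k set");
}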
