Skip to content

Commit 2b9f16e

Browse files
committed Mar 28, 2025
Revert "bugfix: Ensure Loop Termination by Enforcing IEEE-754 Compliance in Sampling Kernels (flashinfer-ai#774)"
This reverts commit a0443d5.
1 parent 79fd1ae commit 2b9f16e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed
 

‎include/flashinfer/sampling.cuh

+6 -6
Original file line number | Diff line number | Diff line change
@@ -827,7 +827,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
827827
__syncthreads();
828828
threadlocal_max_val = temp_storage.max_val;
829829

830-
double low = 0, high = threadlocal_max_val;
830+
float low = 0, high = threadlocal_max_val;
831831
DType min_gt_low, max_le_high;
832832
DType sum_low(1);
833833
// f(x) = sum(probs[probs > x]), f(x) is non-increasing
@@ -839,7 +839,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
839839
// - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
840840
do {
841841
DType threadlocal_sum(0);
842-
double mid = (low + high) / 2;
842+
float mid = (low + high) / 2;
843843
min_gt_low = high;
844844
max_le_high = low;
845845
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -949,7 +949,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
949949
threadlocal_max_val = temp_storage.max_val;
950950
threadlocal_min_val = temp_storage.min_val;
951951

952-
double low = threadlocal_min_val - 1, high = threadlocal_max_val;
952+
float low = threadlocal_min_val - 1, high = threadlocal_max_val;
953953
DType min_gt_low, max_le_high;
954954
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
955955
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
@@ -961,7 +961,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
961961
do {
962962
int threadlocal_count_sum = 0;
963963
int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0
964-
double mid = (low + high) / 2;
964+
float mid = (low + high) / 2;
965965
min_gt_low = high;
966966
max_le_high = low;
967967
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -1067,7 +1067,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
10671067
__syncthreads();
10681068
threadlocal_max_val = temp_storage.max_val;
10691069

1070-
double low = 0, high = threadlocal_max_val;
1070+
float low = 0, high = threadlocal_max_val;
10711071
DType min_gt_low, max_le_high;
10721072
DType sum_low(1);
10731073
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
@@ -1080,7 +1080,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
10801080
do {
10811081
Pair<DType> threadlocal_sum{DType(0), 0};
10821082
Pair<DType> probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0
1083-
double mid = (low + high) / 2;
1083+
float mid = (low + high) / 2;
10841084
min_gt_low = high;
10851085
max_le_high = low;
10861086
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {

0 commit comments

Comments
 (0)
Please sign in to comment.