@@ -827,7 +827,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
827
827
__syncthreads ();
828
828
threadlocal_max_val = temp_storage.max_val ;
829
829
830
- double low = 0 , high = threadlocal_max_val;
830
+ float low = 0 , high = threadlocal_max_val;
831
831
DType min_gt_low, max_le_high;
832
832
DType sum_low (1 );
833
833
// f(x) = sum(probs[probs > x]), f(x) is non-increasing
@@ -839,7 +839,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
839
839
// - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
840
840
do {
841
841
DType threadlocal_sum (0 );
842
- double mid = (low + high) / 2 ;
842
+ float mid = (low + high) / 2 ;
843
843
min_gt_low = high;
844
844
max_le_high = low;
845
845
for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -949,7 +949,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
949
949
threadlocal_max_val = temp_storage.max_val ;
950
950
threadlocal_min_val = temp_storage.min_val ;
951
951
952
- double low = threadlocal_min_val - 1 , high = threadlocal_max_val;
952
+ float low = threadlocal_min_val - 1 , high = threadlocal_max_val;
953
953
DType min_gt_low, max_le_high;
954
954
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
955
955
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
@@ -961,7 +961,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
961
961
do {
962
962
int threadlocal_count_sum = 0 ;
963
963
int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0
964
- double mid = (low + high) / 2 ;
964
+ float mid = (low + high) / 2 ;
965
965
min_gt_low = high;
966
966
max_le_high = low;
967
967
for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -1067,7 +1067,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
1067
1067
__syncthreads ();
1068
1068
threadlocal_max_val = temp_storage.max_val ;
1069
1069
1070
- double low = 0 , high = threadlocal_max_val;
1070
+ float low = 0 , high = threadlocal_max_val;
1071
1071
DType min_gt_low, max_le_high;
1072
1072
DType sum_low (1 );
1073
1073
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
@@ -1080,7 +1080,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
1080
1080
do {
1081
1081
Pair<DType> threadlocal_sum{DType (0 ), 0 };
1082
1082
Pair<DType> probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0
1083
- double mid = (low + high) / 2 ;
1083
+ float mid = (low + high) / 2 ;
1084
1084
min_gt_low = high;
1085
1085
max_le_high = low;
1086
1086
for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
0 commit comments