16
16
#ifndef FLASHINFER_SAMPLING_CUH_
17
17
#define FLASHINFER_SAMPLING_CUH_
18
18
19
- #include < driver_types.h>
20
-
21
19
#include < cub/block/block_adjacent_difference.cuh>
22
20
#include < cub/block/block_reduce.cuh>
23
21
#include < cub/block/block_scan.cuh>
@@ -347,13 +345,13 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
347
345
}
348
346
__syncthreads ();
349
347
if (tx == 0 ) {
348
+ output[bx] = sampled_id;
350
349
if (temp_storage.data .block_aggregate .pair .count >= k) {
351
350
// failed to sample within MAX_TOP_P_ROUNDS
352
351
if (success != nullptr ) {
353
352
success[bx] = false ;
354
353
}
355
354
} else {
356
- output[bx] = sampled_id;
357
355
if (success != nullptr ) {
358
356
success[bx] = true ;
359
357
}
@@ -433,13 +431,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
433
431
}
434
432
__syncthreads ();
435
433
if (tx == 0 ) {
434
+ output[bx] = sampled_id;
436
435
if (float (q) >= top_p) {
437
436
// failed to sample within MAX_TOP_P_ROUNDS
438
437
if (success != nullptr ) {
439
438
success[bx] = false ;
440
439
}
441
440
} else {
442
- output[bx] = sampled_id;
443
441
if (success != nullptr ) {
444
442
success[bx] = true ;
445
443
}
@@ -539,13 +537,13 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
539
537
}
540
538
__syncthreads ();
541
539
if (tx == 0 ) {
540
+ output[bx] = sampled_id;
542
541
if (pivot < scaled_p) {
543
542
// failed to sample within MAX_ROUNDS
544
543
if (success != nullptr ) {
545
544
success[bx] = false ;
546
545
}
547
546
} else {
548
- output[bx] = sampled_id;
549
547
if (success != nullptr ) {
550
548
success[bx] = true ;
551
549
}
@@ -627,13 +625,13 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp
627
625
}
628
626
__syncthreads ();
629
627
if (tx == 0 ) {
628
+ output[bx] = sampled_id;
630
629
if (temp_storage.data .block_aggregate .pair .count >= k || float (q) >= p) {
631
630
// failed to sample within MAX_TOP_P_ROUNDS
632
631
if (success != nullptr ) {
633
632
success[bx] = false ;
634
633
}
635
634
} else {
636
- output[bx] = sampled_id;
637
635
if (success != nullptr ) {
638
636
success[bx] = true ;
639
637
}
@@ -808,7 +806,7 @@ struct RenormTempStorage {
808
806
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
809
807
typename DType>
810
808
__global__ void TopPRenormProbKernel (DType* probs, DType* renormed_prob, DType* top_p_arr,
811
- float top_p_val, float eps, uint32_t d) {
809
+ float top_p_val, uint32_t d) {
812
810
const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
813
811
const uint32_t row_idx = bx;
814
812
float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
@@ -844,12 +842,20 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
844
842
threadlocal_max_val = temp_storage.data .max_val ;
845
843
846
844
float low = 0 , high = threadlocal_max_val;
845
+ DType min_gt_low, max_le_high;
847
846
DType sum_low (1 );
848
- // f(x) = probs[probs > x], f(x) is non-increasing
849
- // loop invariant: f(low) >= p, f(high) < p
850
- while (high - low > eps) {
847
+ // f(x) = sum(probs[probs > x]), f(x) is non-increasing
848
+ // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
849
+ // loop invariant:
850
+ // - f(low) >= p, f(high) < p
851
+ // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
852
+ // stopping condition
853
+ // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
854
+ do {
851
855
DType threadlocal_sum (0 );
852
856
float mid = (low + high) / 2 ;
857
+ min_gt_low = high;
858
+ max_le_high = low;
853
859
for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
854
860
probs_vec.fill (DType (0 ));
855
861
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -858,26 +864,42 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
858
864
#pragma unroll
859
865
for (uint32_t j = 0 ; j < VEC_SIZE; ++j) {
860
866
probs_greater_than_pivot[j] = (probs_vec[j] > mid) ? probs_vec[j] : DType (0 );
867
+ if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
868
+ min_gt_low = min (min_gt_low, probs_vec[j]);
869
+ }
870
+ if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
871
+ max_le_high = max (max_le_high, probs_vec[j]);
872
+ }
861
873
}
862
874
threadlocal_sum +=
863
875
BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
864
876
.Sum <VEC_SIZE>(probs_greater_than_pivot);
865
877
__syncthreads ();
866
878
}
879
+ min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
880
+ .Reduce (min_gt_low, cub::Min ());
881
+ __syncthreads ();
882
+ max_le_high =
883
+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
884
+ .Reduce (max_le_high, cub::Max ());
867
885
if (tx == 0 ) {
868
886
temp_storage.data .block_aggregate .value = threadlocal_sum;
887
+ temp_storage.data .min_val = min_gt_low;
888
+ temp_storage.data .max_val = max_le_high;
869
889
}
870
890
__syncthreads ();
871
891
threadlocal_sum = temp_storage.data .block_aggregate .value ;
892
+ min_gt_low = temp_storage.data .min_val ;
893
+ max_le_high = temp_storage.data .max_val ;
872
894
if (threadlocal_sum >= p) {
873
895
low = mid;
874
896
sum_low = float (threadlocal_sum);
875
897
} else {
876
- high = mid;
898
+ high = min ( mid, max_le_high) ;
877
899
}
878
- }
900
+ } while (min_gt_low != max_le_high);
879
901
880
- DType normalizer = math::ptx_rcp (max (sum_low, eps ));
902
+ DType normalizer = math::ptx_rcp (max (sum_low, 1e-8 ));
881
903
882
904
// normalize
883
905
for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -898,7 +920,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
898
920
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
899
921
typename DType, typename IdType>
900
922
__global__ void TopKMaskLogitsKernel (DType* logits, DType* masked_logits, IdType* top_k_arr,
901
- uint32_t top_k_val, float eps, uint32_t d) {
923
+ uint32_t top_k_val, uint32_t d) {
902
924
const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
903
925
const uint32_t row_idx = bx;
904
926
uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -941,12 +963,20 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
941
963
threadlocal_min_val = temp_storage.data .min_val ;
942
964
943
965
float low = threadlocal_min_val - 1 , high = threadlocal_max_val;
966
+ DType min_gt_low, max_le_high;
944
967
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
945
- // loop invariant: f(low) >= k, f(high) < k
946
- while (high - low > eps) {
968
+ // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
969
+ // loop invariant:
970
+ // - f(low) >= k, f(high) < k
971
+ // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
972
+ // stopping condition: min_gt_low == max_le_high
973
+ // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
974
+ do {
947
975
int threadlocal_count_sum = 0 ;
948
976
int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0
949
977
float mid = (low + high) / 2 ;
978
+ min_gt_low = high;
979
+ max_le_high = low;
950
980
for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
951
981
logits_vec.fill (DType (0 ));
952
982
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -956,23 +986,41 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
956
986
for (uint32_t j = 0 ; j < VEC_SIZE; ++j) {
957
987
probs_greater_than_pivot_count[j] =
958
988
logits_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
989
+ if (logits_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
990
+ min_gt_low = min (min_gt_low, logits_vec[j]);
991
+ }
992
+ if (logits_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
993
+ max_le_high = max (max_le_high, logits_vec[j]);
994
+ }
959
995
}
960
996
threadlocal_count_sum +=
961
997
BlockReduce<int , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce_int )
962
998
.Sum <VEC_SIZE>(probs_greater_than_pivot_count);
963
999
__syncthreads ();
964
1000
}
1001
+ min_gt_low =
1002
+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1003
+ .Reduce (min_gt_low, cub::Min ());
1004
+ __syncthreads ();
1005
+ max_le_high =
1006
+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1007
+ .Reduce (max_le_high, cub::Max ());
1008
+ __syncthreads ();
965
1009
if (tx == 0 ) {
966
1010
temp_storage.data .block_aggregate .count = threadlocal_count_sum;
1011
+ temp_storage.data .min_val = min_gt_low;
1012
+ temp_storage.data .max_val = max_le_high;
967
1013
}
968
1014
__syncthreads ();
969
1015
threadlocal_count_sum = temp_storage.data .block_aggregate .count ;
1016
+ min_gt_low = temp_storage.data .min_val ;
1017
+ max_le_high = temp_storage.data .max_val ;
970
1018
if (threadlocal_count_sum >= k) {
971
1019
low = mid;
972
1020
} else {
973
- high = mid;
1021
+ high = min ( mid, max_le_high) ;
974
1022
}
975
- }
1023
+ } while (min_gt_low != max_le_high);
976
1024
pivot = low;
977
1025
}
978
1026
@@ -996,7 +1044,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
996
1044
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
997
1045
typename DType, typename IdType>
998
1046
__global__ void TopKRenormProbKernel (DType* probs, DType* renormed_prob, IdType* top_k_arr,
999
- uint32_t top_k_val, float eps, uint32_t d) {
1047
+ uint32_t top_k_val, uint32_t d) {
1000
1048
const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
1001
1049
const uint32_t row_idx = bx;
1002
1050
uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -1033,13 +1081,21 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
1033
1081
threadlocal_max_val = temp_storage.data .max_val ;
1034
1082
1035
1083
float low = 0 , high = threadlocal_max_val;
1084
+ DType min_gt_low, max_le_high;
1036
1085
DType sum_low (1 );
1037
1086
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
1038
- // loop invariant: f(low) >= k, f(high) < k
1039
- while (high - low > eps) {
1087
+ // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
1088
+ // loop invariant:
1089
+ // - f(low) >= k, f(high) < k
1090
+ // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
1091
+ // stopping condition: min_gt_low == max_le_high
1092
+ // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
1093
+ do {
1040
1094
Pair<DType> threadlocal_sum{DType (0 ), 0 };
1041
1095
Pair<DType> probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0
1042
1096
float mid = (low + high) / 2 ;
1097
+ min_gt_low = high;
1098
+ max_le_high = low;
1043
1099
for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
1044
1100
probs_vec.fill (DType (0 ));
1045
1101
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1050,26 +1106,44 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
1050
1106
probs_greater_than_pivot_pair[j] = {
1051
1107
(probs_vec[j] > mid) ? probs_vec[j] : DType (0 ),
1052
1108
(probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
1109
+ if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
1110
+ min_gt_low = min (min_gt_low, probs_vec[j]);
1111
+ }
1112
+ if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
1113
+ max_le_high = max (max_le_high, probs_vec[j]);
1114
+ }
1053
1115
}
1054
1116
threadlocal_sum += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
1055
1117
temp_storage.block_prim .reduce_pair )
1056
1118
.Sum <VEC_SIZE>(probs_greater_than_pivot_pair);
1057
1119
__syncthreads ();
1058
1120
}
1121
+ min_gt_low =
1122
+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1123
+ .Reduce (min_gt_low, cub::Min ());
1124
+ __syncthreads ();
1125
+ max_le_high =
1126
+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1127
+ .Reduce (max_le_high, cub::Max ());
1128
+ __syncthreads ();
1059
1129
if (tx == 0 ) {
1060
1130
temp_storage.data .block_aggregate .pair = threadlocal_sum;
1131
+ temp_storage.data .min_val = min_gt_low;
1132
+ temp_storage.data .max_val = max_le_high;
1061
1133
}
1062
1134
__syncthreads ();
1063
1135
threadlocal_sum = temp_storage.data .block_aggregate .pair ;
1136
+ min_gt_low = temp_storage.data .min_val ;
1137
+ max_le_high = temp_storage.data .max_val ;
1064
1138
if (threadlocal_sum.count >= k) {
1065
1139
low = mid;
1066
1140
sum_low = float (threadlocal_sum.value );
1067
1141
} else {
1068
- high = mid;
1142
+ high = min ( mid, max_le_high) ;
1069
1143
}
1070
- }
1144
+ } while (min_gt_low != max_le_high);
1071
1145
1072
- normalizer = math::ptx_rcp (max (sum_low, eps ));
1146
+ normalizer = math::ptx_rcp (max (sum_low, 1e-8 ));
1073
1147
pivot = low;
1074
1148
}
1075
1149
@@ -1090,7 +1164,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
1090
1164
}
1091
1165
1092
1166
template <typename DType>
1093
- cudaError_t TopPRenormProb (DType* probs, DType* renormed_prob, DType* top_p_arr, float eps,
1167
+ cudaError_t TopPRenormProb (DType* probs, DType* renormed_prob, DType* top_p_arr,
1094
1168
uint32_t batch_size, float top_p_val, uint32_t d,
1095
1169
cudaStream_t stream = 0 ) {
1096
1170
const uint32_t BLOCK_THREADS = 1024 ;
@@ -1099,7 +1173,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
1099
1173
const uint32_t smem_size = sizeof (RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
1100
1174
dim3 nblks (batch_size);
1101
1175
dim3 nthrs (BLOCK_THREADS);
1102
- void * args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &eps, & d};
1176
+ void * args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d};
1103
1177
DISPATCH_ALIGNED_VEC_SIZE (vec_size, VEC_SIZE, {
1104
1178
auto kernel = TopPRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType>;
1105
1179
FLASHINFER_CUDA_CALL (
@@ -1110,7 +1184,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
1110
1184
}
1111
1185
1112
1186
template <typename DType, typename IdType>
1113
- cudaError_t TopKRenormProb (DType* probs, DType* renormed_prob, IdType* top_k_arr, float eps,
1187
+ cudaError_t TopKRenormProb (DType* probs, DType* renormed_prob, IdType* top_k_arr,
1114
1188
uint32_t batch_size, uint32_t top_k_val, uint32_t d,
1115
1189
cudaStream_t stream = 0 ) {
1116
1190
const uint32_t BLOCK_THREADS = 1024 ;
@@ -1119,7 +1193,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr
1119
1193
const uint32_t smem_size = sizeof (RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
1120
1194
dim3 nblks (batch_size);
1121
1195
dim3 nthrs (BLOCK_THREADS);
1122
- void * args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &eps, & d};
1196
+ void * args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d};
1123
1197
DISPATCH_ALIGNED_VEC_SIZE (vec_size, VEC_SIZE, {
1124
1198
auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
1125
1199
FLASHINFER_CUDA_CALL (
@@ -1130,7 +1204,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr
1130
1204
}
1131
1205
1132
1206
template <typename DType, typename IdType>
1133
- cudaError_t TopKMaskLogits (DType* logits, DType* masked_logits, IdType* top_k_arr, float eps,
1207
+ cudaError_t TopKMaskLogits (DType* logits, DType* masked_logits, IdType* top_k_arr,
1134
1208
uint32_t batch_size, uint32_t top_k_val, uint32_t d,
1135
1209
cudaStream_t stream = 0 ) {
1136
1210
const uint32_t BLOCK_THREADS = 1024 ;
@@ -1139,7 +1213,7 @@ cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_ar
1139
1213
const uint32_t smem_size = sizeof (RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
1140
1214
dim3 nblks (batch_size);
1141
1215
dim3 nthrs (BLOCK_THREADS);
1142
- void * args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &eps, & d};
1216
+ void * args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d};
1143
1217
DISPATCH_ALIGNED_VEC_SIZE (vec_size, VEC_SIZE, {
1144
1218
auto kernel = TopKMaskLogitsKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
1145
1219
FLASHINFER_CUDA_CALL (
0 commit comments