 #include <cub/block/block_scan.cuh>
 #include <cuda/std/limits>
 #include <numeric>
+#include <tuple>
 
 #include "math.cuh"
 #include "utils.cuh"
@@ -97,10 +98,10 @@ struct SamplingTempStorage {
   struct {
     int32_t sampled_id;
     int32_t last_valid_id;
+    float max_val;
     union {
       float value;
       ValueCount<float> pair;
-      float max_p;
     } block_aggregate;
   };
 };
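The hunk above moves the row maximum out of the anonymous union. One reading of the change (an inference, not stated in the patch): the union members are per-round scratch that every block-wide scan or reduce rewrites, while max_val now has to survive for the whole kernel as the upper bound of the pivot search, so it needs a slot that nothing else aliases. A self-contained analogue, with a stand-in ValueCount:

// Stand-in for the ValueCount<T> in the real header (assumed shape).
template <typename T>
struct ValueCount {
  T value;
  int count;
};

struct SamplingScratchSketch {
  float max_val;  // persists across sampling rounds: pivot-search upper bound
  union {         // transient: each block-wide scan/reduce rewrites this storage,
    float value;  // so a max kept here (like the old max_p) would be clobbered
    ValueCount<float> pair;
  } block_aggregate;
};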
@@ -190,6 +191,75 @@ __device__ __forceinline__ void DeterministicInclusiveSum(
   }
 }
 
+template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          typename TempStorage>
+__device__ __forceinline__ std::tuple<float, float> GetMinMaxValue(float* in_data, uint32_t row_idx,
+                                                                   uint32_t d,
+                                                                   TempStorage& temp_storage) {
+  const uint32_t tx = threadIdx.x;
+  vec_t<float, VEC_SIZE> in_data_vec;
+  float max_val = -cuda::std::numeric_limits<float>::infinity(),
+        min_val = cuda::std::numeric_limits<float>::infinity();
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    in_data_vec.fill(0);
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      in_data_vec.cast_load(in_data + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+    float in_data_[VEC_SIZE];
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      in_data_[j] = in_data_vec[j];
+    }
+    max_val = max(
+        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                     .Reduce<VEC_SIZE>(in_data_, cub::Max()));
+    __syncthreads();
+    min_val = min(
+        min_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                     .Reduce<VEC_SIZE>(in_data_, cub::Min()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.max_val = max_val;
+    temp_storage.min_val = min_val;
+  }
+  __syncthreads();
+  max_val = temp_storage.max_val;
+  min_val = temp_storage.min_val;
+
+  return std::make_tuple(min_val, max_val);
+}
+
+template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          typename TempStorage>
+__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
+                                             TempStorage& temp_storage) {
+  const uint32_t tx = threadIdx.x;
+  vec_t<float, VEC_SIZE> in_data_vec;
+
+  float max_val = 0;
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    in_data_vec.fill(0);
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+    float in_data_[VEC_SIZE];
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      in_data_[j] = in_data_vec[j];
+    }
+    max_val = max(
+        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                     .Reduce<VEC_SIZE>(in_data_, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.max_val = max_val;
+  }
+  __syncthreads();
+  return temp_storage.max_val;
+}
+
 template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename Predicate>
 __device__ __forceinline__ void DeviceSamplingFromProb(
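GetMinMaxValue and GetMaxValue factor out a pattern that previously appeared inline in several kernels below: stride over one row with vectorized loads, reduce each tile with cub::BlockReduce, then publish thread 0's aggregate through shared memory so every thread returns the same value. A stripped-down sketch of that pattern, dropping the vec_t vectorized loads and the shared temp_storage plumbing of the real helpers (the helper name and scalar loop are illustrative assumptions):

#include <cstdint>
#include <cub/block/block_reduce.cuh>
#include <math_constants.h>  // CUDART_INF_F

// Block-uniform max over row_data[0..d).
template <uint32_t BLOCK_THREADS>
__device__ float BlockRowMax(const float* row_data, uint32_t d) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;
  __shared__ float broadcast;

  float thread_max = -CUDART_INF_F;
  // Each thread strides over its share of the row.
  for (uint32_t i = threadIdx.x; i < d; i += BLOCK_THREADS) {
    thread_max = fmaxf(thread_max, row_data[i]);
  }
  // Reduce() yields a valid aggregate only on thread 0, hence the explicit
  // broadcast through shared memory, mirroring GetMaxValue above.
  float block_max = BlockReduce(reduce_storage).Reduce(thread_max, cub::Max());
  if (threadIdx.x == 0) broadcast = block_max;
  __syncthreads();
  return broadcast;
}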
@@ -335,10 +405,14 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
           smem_sampling);
 
+  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                              SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
+      probs, row_idx, d, temp_storage);
+
   vec_t<float, VEC_SIZE> probs_vec;
   float aggregate;
   float q = 1;
-  double low = 0, high = 1;
+  double low = 0, high = max_val;
   int sampled_id;
   int round = 0;
   do {
@@ -448,10 +522,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
           smem_sampling);
 
+  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                              SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
+      probs, row_idx, d, temp_storage);
+
   vec_t<float, VEC_SIZE> probs_vec;
   float aggregate;
   float q = 1;
-  double low = 0, high = 1;
+  double low = 0, high = max_val;
   int sampled_id;
   do {
     temp_storage.sampled_id = d;
@@ -552,29 +630,12 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
       reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
           smem_sampling);
 
-  vec_t<float, VEC_SIZE> probs_vec;
-
-  float max_p = 0;
-  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
-    probs_vec.fill(0);
-    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
-    }
-    float probs_[VEC_SIZE];
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      probs_[j] = probs_vec[j];
-    }
-    max_p = max(max_p, BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
-                           .Reduce<VEC_SIZE>(probs_, cub::Max()));
-    __syncthreads();
-  }
-  if (tx == 0) {
-    temp_storage.block_aggregate.max_p = max_p;
-  }
-  __syncthreads();
-  float pivot = temp_storage.block_aggregate.max_p * p;
+  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                              SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
+      probs, row_idx, d, temp_storage);
+  float pivot = max_val * p;
 
+  vec_t<float, VEC_SIZE> probs_vec;
   float aggregate_gt_pivot = 0;
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
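For context on the min-p hunk above: min-p sampling keeps only tokens whose probability is at least the fraction p of the row maximum, so the kernel's pivot is simply max_val * p. A toy host-side restatement of the masking rule (illustrative only; the function name is not from the patch):

#include <algorithm>
#include <vector>

// Zero out entries below max_prob * p; sampling then draws from the survivors.
std::vector<float> MinPFilter(std::vector<float> probs, float p) {
  const float max_prob = *std::max_element(probs.begin(), probs.end());
  const float pivot = max_prob * p;  // same pivot as `max_val * p` in the kernel
  for (float& q : probs) {
    if (q < pivot) q = 0.0f;
  }
  return probs;
}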
@@ -648,10 +709,14 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
       reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
           smem_sampling);
 
+  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                              SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
+      probs, row_idx, d, temp_storage);
+
   vec_t<float, VEC_SIZE> probs_vec;
   float aggregate;
   float q = 1;
-  double low = 0, high = 1;
+  double low = 0, high = max_val;
   int sampled_id;
   do {
     temp_storage.sampled_id = d;
@@ -904,29 +969,11 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
   vec_t<float, VEC_SIZE> probs_vec;
   float probs_greater_than_pivot[VEC_SIZE];  // pivot initialized to 0
 
-  float threadlocal_max_val = 0;
-  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
-    probs_vec.fill(0);
-    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      probs_greater_than_pivot[j] = probs_vec[j];
-    }
-    threadlocal_max_val =
-        max(threadlocal_max_val,
-            BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                .Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
-    __syncthreads();
-  }
-  if (tx == 0) {
-    temp_storage.max_val = threadlocal_max_val;
-  }
-  __syncthreads();
-  threadlocal_max_val = temp_storage.max_val;
+  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                              RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(probs, row_idx, d,
+                                                                                  temp_storage);
 
-  double low = 0, high = threadlocal_max_val;
+  double low = 0, high = max_val;
   float min_gt_low, max_le_high;
   float sum_low = 1;
   // f(x) = sum(probs[probs > x]), f(x) is non-increasing
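The comment on the last context line above is the invariant that justifies seeding the search with high = max_val: f(x) = sum(probs[probs > x]) is non-increasing and already zero at x = max_val, so [0, max_val] brackets the pivot and is a tighter starting bracket than the old [0, 1]. A host-side sketch of the bisection the renorm kernels perform per row (illustrative only; names and the fixed iteration count are assumptions):

#include <vector>

// Find a pivot with sum(probs[probs > pivot]) ~= p by bisection on [0, max_val].
float FindTopPPivot(const std::vector<float>& probs, float p, float max_val) {
  double low = 0, high = max_val;
  for (int round = 0; round < 32; ++round) {  // fixed round count for clarity
    const double mid = (low + high) / 2;
    double mass = 0;
    for (float q : probs) {
      if (q > mid) mass += q;  // f(mid)
    }
    if (mass > p) {
      low = mid;   // keeping too much mass: raise the pivot
    } else {
      high = mid;  // keeping too little: lower the pivot
    }
  }
  return static_cast<float>(high);
}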
@@ -1019,37 +1066,11 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
       reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
   float logits_greater_than_pivot[VEC_SIZE];  // pivot initialized to 0
 
-  float threadlocal_max_val = -cuda::std::numeric_limits<float>::infinity(),
-        threadlocal_min_val = cuda::std::numeric_limits<float>::infinity();
-  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
-    logits_vec.fill(0);
-    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      logits_greater_than_pivot[j] = logits_vec[j];
-    }
-    threadlocal_max_val =
-        max(threadlocal_max_val,
-            BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                .Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Max()));
-    __syncthreads();
-    threadlocal_min_val =
-        min(threadlocal_min_val,
-            BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                .Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Min()));
-    __syncthreads();
-  }
-  if (tx == 0) {
-    temp_storage.max_val = threadlocal_max_val;
-    temp_storage.min_val = threadlocal_min_val;
-  }
-  __syncthreads();
-  threadlocal_max_val = temp_storage.max_val;
-  threadlocal_min_val = temp_storage.min_val;
+  auto [min_val, max_val] = GetMinMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                                           RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
+      logits, row_idx, d, temp_storage);
 
-  double low = threadlocal_min_val - 1, high = threadlocal_max_val;
+  double low = min_val - 1, high = max_val;
   float min_gt_low, max_le_high;
   // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
   // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
@@ -1144,29 +1165,11 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
   temp_storage.max_val = 0;
   float probs_greater_than_pivot[VEC_SIZE];  // pivot initialized to 0
 
-  float threadlocal_max_val = 0;
-  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
-    probs_vec.fill(0);
-    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      probs_greater_than_pivot[j] = probs_vec[j];
-    }
-    threadlocal_max_val =
-        max(threadlocal_max_val,
-            BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                .Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
-    __syncthreads();
-  }
-  if (tx == 0) {
-    temp_storage.max_val = threadlocal_max_val;
-  }
-  __syncthreads();
-  threadlocal_max_val = temp_storage.max_val;
+  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                              RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
+      probs, row_idx, d, temp_storage);
 
-  double low = 0, high = threadlocal_max_val;
+  double low = 0, high = max_val;
   float min_gt_low, max_le_high;
   float sum_low = 1;
   // f(x) = len(nonzero(probs > x)), f(x) is non-increasing