@@ -184,18 +184,18 @@ __device__ __forceinline__ void DeterministicInclusiveSum(
 }
 
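+// pred(x) decides whether probability x takes part in the sampling scan,
+// generalizing the previous fixed comparison against a threshold value.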
 template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
-          BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename T>
+          BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename T, typename Predicate>
 __device__ __forceinline__ void DeviceSamplingFromProb(
-    uint32_t i, uint32_t d, T threshold, T u, vec_t<T, VEC_SIZE> prob_vec, T& aggregate,
+    uint32_t i, uint32_t d, Predicate pred, T u, vec_t<T, VEC_SIZE> prob_vec, T& aggregate,
     SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>* temp_storage) {
   const uint32_t tx = threadIdx.x;
   T prob_greater_than_threshold[VEC_SIZE];
   T inclusive_cdf[VEC_SIZE];
   bool greater_than_u[VEC_SIZE], valid[VEC_SIZE];
 #pragma unroll
   for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-    prob_greater_than_threshold[j] = (prob_vec[j] > threshold) ? prob_vec[j] : T(0);
-    valid[j] = prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d;
+    prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : T(0);
+    valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE < d;
   }
   T aggregate_local =
       BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
@@ -219,7 +219,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
 
 #pragma unroll
   for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-    greater_than_u[j] = inclusive_cdf[j] + aggregate > u;
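+    // && valid[j] masks lanes whose element fails pred or lies past the end
+    // of the vocabulary, so they can never be selected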
+    greater_than_u[j] = (inclusive_cdf[j] + aggregate > u) && valid[j];
   }
 
   bool greater_than_u_diff[VEC_SIZE];
@@ -234,13 +234,8 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
 
 #pragma unroll
   for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-    if (greater_than_u_diff[j] && valid[j]) {
-      if constexpr (DETERMINISTIC) {
-        temp_storage->sampled_id = (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
-      } else {
-        // cub's block scan result might not be monotonic, so we need to find the first element
-        atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
-      }
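+    // the block scan result might not be monotonic, so keep the smallest
+    // qualifying index via atomicMin in both the deterministic and the
+    // cub-based modes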
+    if (greater_than_u_diff[j]) {
+      atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
     }
   }
   __syncthreads();
@@ -275,7 +270,8 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
     }
 
     DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
-                           DType>(i, d, DType(0), u, probs_vec, aggregate, &temp_storage);
+                           DType>(
+        i, d, [](DType x) { return x > 0; }, u, probs_vec, aggregate, &temp_storage);
     if (float(aggregate) > u) {
       break;
     }
@@ -316,8 +312,8 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
       }
 
       DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
-                             DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
-                                                   &temp_storage);
+                             DETERMINISTIC, DType>(
+          i, d, [&](DType x) { return x > pivot; }, u, probs_vec, aggregate, &temp_storage);
       if (aggregate > u) {
         break;
       }
@@ -404,8 +400,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
       }
 
       DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
-                             DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
-                                                   &temp_storage);
+                             DETERMINISTIC, DType>(
+          i, d, [&](DType x) { return x > pivot; }, u, probs_vec, aggregate, &temp_storage);
       if (aggregate > u) {
         break;
       }
@@ -459,8 +455,7 @@ template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
 __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, DType* min_p_arr,
-                                           IdType* output, bool* success, float min_p_val,
-                                           uint32_t d, uint32_t max_min_p_rounds) {
+                                           IdType* output, float min_p_val, uint32_t d) {
   const uint32_t batch_size = gridDim.x;
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   DType p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx];
@@ -472,9 +467,6 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
       SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
 
   vec_t<DType, VEC_SIZE> probs_vec;
-  DType aggregate;
-  DType q = DType(1);
-  DType pivot = DType(0);
 
   DType max_p = 0;
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -495,70 +487,50 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
     temp_storage.block_aggregate.max_p = max_p;
   }
   __syncthreads();
-  DType scaled_p = temp_storage.block_aggregate.max_p * p;
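+  // min-p cutoff: only tokens with probability >= max_p * p stay in the support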
+  DType pivot = temp_storage.block_aggregate.max_p * p;
 
-  IdType sampled_id;
-  for (uint32_t round = 0; round < max_min_p_rounds; ++round) {
-    temp_storage.sampled_id = d - 1;
-    __syncthreads();
-    DType u = uniform_samples[round * batch_size + bx] * q;
-    aggregate = DType(0);
-    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
-      probs_vec.fill(DType(0));
-      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
-      }
-
-      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
-                             DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
-                                                   &temp_storage);
-      if (aggregate > u) {
-        break;
-      }
-    }
-    __syncthreads();
-    sampled_id = temp_storage.sampled_id;
-    pivot = max(pivot, probs[bx * d + sampled_id]);
-    if (pivot >= scaled_p) {
-      break;
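+  // single pass over the vocabulary accumulates the total mass above the
+  // pivot; this renormalizer replaces the earlier multi-round rejection loop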
+  DType aggregate_gt_pivot = DType(0);
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
 
-    DType aggregate_gt_pivot = DType(0);
-    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
-      probs_vec.fill(DType(0));
-      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
-      }
-
-      DType probs_gt_pivot[VEC_SIZE];
+    DType probs_gt_pivot[VEC_SIZE];
 #pragma unroll
-      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-        probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
-      }
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_gt_pivot[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : DType(0);
+    }
 
-      aggregate_gt_pivot += BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
-                                .Sum<VEC_SIZE>(probs_gt_pivot);
-      if (tx == 0) {
-        temp_storage.block_aggregate.value = aggregate_gt_pivot;
-      }
-      __syncthreads();
+    aggregate_gt_pivot += BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+                              .Sum<VEC_SIZE>(probs_gt_pivot);
+    if (tx == 0) {
+      temp_storage.block_aggregate.value = aggregate_gt_pivot;
     }
-    q = temp_storage.block_aggregate.value;
+    __syncthreads();
   }
+
+  DType aggregate(0);
+  DType q = temp_storage.block_aggregate.value;
+
+  IdType sampled_id;
+  temp_storage.sampled_id = d - 1;
   __syncthreads();
-  if (tx == 0) {
-    output[bx] = sampled_id;
-    if (pivot < scaled_p) {
-      // failed to sample within MAX_ROUNDS
-      if (success != nullptr) {
-        success[bx] = false;
-      }
-    } else {
-      if (success != nullptr) {
-        success[bx] = true;
-      }
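+  // drawing u from [0, q) samples the renormalized surviving mass directly,
+  // so a single scan suffices and no retry rounds are needed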
+  DType u = uniform_samples[bx] * q;
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+
+    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
+                           DType>(
+        i, d, [&](DType x) { return x >= pivot; }, u, probs_vec, aggregate, &temp_storage);
+    if (aggregate > u) {
+      break;
     }
   }
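+  // sampled_id falls back to d - 1 if rounding keeps the scan from crossing
+  // u, so every block writes a valid token and the success flag is gone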
+  output[bx] = temp_storage.sampled_id;
 }
 
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
@@ -596,8 +568,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp
       }
 
       DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
-                             DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
-                                                   &temp_storage);
+                             DETERMINISTIC, DType>(
+          i, d, [&](DType x) { return x > pivot; }, u, probs_vec, aggregate, &temp_storage);
       if (aggregate > u) {
         break;
       }
@@ -749,16 +721,15 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
 
 template <typename T, typename IdType>
 cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr, IdType* output,
-                                 bool* success, uint32_t batch_size, float min_p_val, uint32_t d,
-                                 uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) {
+                                 uint32_t batch_size, float min_p_val, uint32_t d,
+                                 bool deterministic, cudaStream_t stream = 0) {
   constexpr uint32_t BLOCK_THREADS = 1024;
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
 
   const uint32_t smem_size = sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&probs, &uniform_samples, &min_p_arr, &output,
-                  &success, &min_p_val, &d, &max_rounds};
+  void* args[] = {&probs, &uniform_samples, &min_p_arr, &output, &min_p_val, &d};
 
   DISPATCH_ALIGNED_VEC_SIZE(
       vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
@@ -1350,8 +1321,9 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
     }
 
     DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
-                           DType>(i, d, DType(0), u, relu_q_minus_p_vec, aggregate_relu_q_minus_p,
-                                  &temp_storage);
+                           DType>(
+        i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p,
+        &temp_storage);
     if (aggregate_relu_q_minus_p > u) {
       break;
     }