Skip to content

Commit 0f80329

Browse files
authored
sampling: simplify min-p sampling (#713)
As suggested by @aikitoria and @rolandtannous in #710 , we don't need rejection algorithm for min-p sampling, this PR simplifies the design. There is a breaking change on API: we no longer returns `success` array for min-p sampling because there is no risk of rejecting the samples. In later PRs, we will also remove the `success` array in top-p/top-k sampling APIs.
1 parent 561f646 commit 0f80329

File tree

5 files changed

+83
-119
lines changed

5 files changed

+83
-119
lines changed

Diff for: csrc/flashinfer_sampling_ops.cu

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at:
2727
unsigned int top_k_val, bool deterministic, int64_t cuda_stream);
2828

2929
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
30-
at::Tensor success, std::optional<at::Tensor> maybe_min_p_arr,
31-
double min_p_val, bool deterministic, int64_t cuda_stream);
30+
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
31+
bool deterministic, int64_t cuda_stream);
3232

3333
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples,
3434
at::Tensor samples, at::Tensor success,

Diff for: csrc/sampling.cu

+6-7
Original file line numberDiff line numberDiff line change
@@ -90,26 +90,25 @@ void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at:
9090
}
9191

9292
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
93-
at::Tensor success, std::optional<at::Tensor> maybe_min_p_arr,
94-
double min_p_val, bool deterministic, int64_t cuda_stream) {
93+
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
94+
bool deterministic, int64_t cuda_stream) {
9595
CHECK_INPUT(probs);
9696
CHECK_INPUT(uniform_samples);
9797
auto device = probs.device();
9898
CHECK_EQ(uniform_samples.device(), device);
9999
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
100-
CHECK_DIM(2, uniform_samples); // uniform_samples: (max_rounds, batch_size)
100+
CHECK_DIM(1, uniform_samples); // uniform_samples: (batch_size)
101101
unsigned int batch_size = probs.size(0);
102102
unsigned int vocab_size = probs.size(1);
103-
unsigned int max_rounds = uniform_samples.size(0);
104-
CHECK_EQ(uniform_samples.size(1), batch_size);
103+
CHECK_EQ(uniform_samples.size(0), batch_size);
105104
bool has_min_p_arr = maybe_min_p_arr.has_value();
106105

107106
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
108107
cudaError_t status = sampling::MinPSamplingFromProb<float, int>(
109108
static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
110109
has_min_p_arr ? static_cast<float*>(maybe_min_p_arr->data_ptr()) : nullptr,
111-
static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), batch_size,
112-
min_p_val, vocab_size, max_rounds, deterministic, stream);
110+
static_cast<int*>(samples.data_ptr()), batch_size, min_p_val, vocab_size, deterministic,
111+
stream);
113112
TORCH_CHECK(status == cudaSuccess, "MinPSamplingFromProb failed with error code " +
114113
std::string(cudaGetErrorString(status)));
115114
}

Diff for: flashinfer/sampling.py

+11-15
Original file line numberDiff line numberDiff line change
@@ -164,26 +164,24 @@ def min_p_sampling_from_probs(
164164
maybe_min_p_arr: Optional[torch.Tensor],
165165
min_p_val: float,
166166
deterministic: bool,
167-
) -> Tuple[torch.Tensor, torch.Tensor]:
167+
) -> torch.Tensor:
168168
with probs.device as device:
169169
probs = probs.float()
170170
uniform_samples = uniform_samples.float()
171171
maybe_min_p_arr = (
172172
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
173173
)
174174
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
175-
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
176175
module.min_p_sampling_from_probs(
177176
probs,
178177
uniform_samples,
179178
samples,
180-
success,
181179
maybe_min_p_arr,
182180
min_p_val,
183181
deterministic,
184182
get_cuda_stream(device),
185183
)
186-
return samples, success
184+
return samples
187185

188186
# torch library for top_k_top_p_sampling_from_probs
189187

@@ -634,7 +632,7 @@ def min_p_sampling_from_probs(
634632
min_p: Union[torch.Tensor, float],
635633
deterministic: bool = True,
636634
check_nan: bool = False,
637-
) -> Tuple[torch.Tensor, torch.Tensor]:
635+
) -> torch.Tensor:
638636
r"""Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities,
639637
640638
this operator implements GPU-based rejection sampling without explicit sorting.
@@ -647,8 +645,7 @@ def min_p_sampling_from_probs(
647645
probs: torch.Tensor
648646
Probabilities, shape ``(batch_size, num_classes)``.
649647
uniform_samples: torch.Tensor
650-
The uniform samples used as needle for sampling, shape ``(max_top_k_rounds, batch_size,)``,
651-
where the first dimension is the maximum number of rounds for rejection sampling.
648+
The uniform samples used as needle for sampling, shape ``(batch_size,)``,
652649
Expected to be uniformly distributed in ``[0, 1)``.
653650
min_p: torch.Tensor
654651
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
@@ -663,9 +660,6 @@ def min_p_sampling_from_probs(
663660
-------
664661
samples: torch.Tensor
665662
Sampled categories, shape ``(batch_size,)``.
666-
success: torch.Tensor
667-
Whether the sampling is successful within ``max_top_k_rounds`` rounds,
668-
shape ``(batch_size,)``.
669663
670664
Examples
671665
--------
@@ -676,7 +670,6 @@ def min_p_sampling_from_probs(
676670
<torch._C.Generator object at 0x7f8b3db06df0>
677671
>>> batch_size = 4
678672
>>> vocab_size = 5
679-
>>> max_rounds = 3
680673
>>> min_p = torch.full((batch_size,), 0.05).to(0)
681674
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
682675
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
@@ -685,19 +678,22 @@ def min_p_sampling_from_probs(
685678
[0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
686679
[0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
687680
[0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
688-
>>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
689-
>>> samples, success = flashinfer.sampling.min_p_sampling_from_probs(norm_prob, uniform_samples, min_p)
681+
>>> uniform_samples = torch.rand(batch_size).to(0)
682+
>>> samples = flashinfer.sampling.min_p_sampling_from_probs(norm_prob, uniform_samples, min_p)
690683
>>> samples
691684
tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32)
692-
>>> success
693-
tensor([True, True, True, True], device='cuda:0')
694685
695686
Note
696687
----
697688
This function expects float32 inputs, and the output is int32.
698689
We encourage users to set ``max_rounds`` to a reasonable value, e.g., 32. The actual
699690
implementation usually use much fewer rounds for rejection sampling because of early stopping.
700691
"""
692+
# NOTE(Zihao): for backward compatiblity (https://github.com/flashinfer-ai/flashinfer/pull/713)
693+
if uniform_samples.dim() == 2:
694+
# Take the first row (round) of uniform_samples
695+
uniform_samples = uniform_samples[0]
696+
701697
if check_nan:
702698
if torch.any(torch.isnan(probs)):
703699
raise ValueError("Input probs contains NaN.")

Diff for: include/flashinfer/sampling.cuh

+56-84
Original file line numberDiff line numberDiff line change
@@ -184,18 +184,18 @@ __device__ __forceinline__ void DeterministicInclusiveSum(
184184
}
185185

186186
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
187-
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename T>
187+
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename T, typename Predicate>
188188
__device__ __forceinline__ void DeviceSamplingFromProb(
189-
uint32_t i, uint32_t d, T threshold, T u, vec_t<T, VEC_SIZE> prob_vec, T& aggregate,
189+
uint32_t i, uint32_t d, Predicate pred, T u, vec_t<T, VEC_SIZE> prob_vec, T& aggregate,
190190
SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>* temp_storage) {
191191
const uint32_t tx = threadIdx.x;
192192
T prob_greater_than_threshold[VEC_SIZE];
193193
T inclusive_cdf[VEC_SIZE];
194194
bool greater_than_u[VEC_SIZE], valid[VEC_SIZE];
195195
#pragma unroll
196196
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
197-
prob_greater_than_threshold[j] = (prob_vec[j] > threshold) ? prob_vec[j] : T(0);
198-
valid[j] = prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d;
197+
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : T(0);
198+
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE < d;
199199
}
200200
T aggregate_local =
201201
BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
@@ -219,7 +219,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
219219

220220
#pragma unroll
221221
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
222-
greater_than_u[j] = inclusive_cdf[j] + aggregate > u;
222+
greater_than_u[j] = (inclusive_cdf[j] + aggregate > u) && valid[j];
223223
}
224224

225225
bool greater_than_u_diff[VEC_SIZE];
@@ -234,13 +234,8 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
234234

235235
#pragma unroll
236236
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
237-
if (greater_than_u_diff[j] && valid[j]) {
238-
if constexpr (DETERMINISTIC) {
239-
temp_storage->sampled_id = (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
240-
} else {
241-
// cub's block scan result might not be monotonic, so we need to find the first element
242-
atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
243-
}
237+
if (greater_than_u_diff[j]) {
238+
atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
244239
}
245240
}
246241
__syncthreads();
@@ -275,7 +270,8 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
275270
}
276271

277272
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
278-
DType>(i, d, DType(0), u, probs_vec, aggregate, &temp_storage);
273+
DType>(
274+
i, d, [](DType x) { return x > 0; }, u, probs_vec, aggregate, &temp_storage);
279275
if (float(aggregate) > u) {
280276
break;
281277
}
@@ -316,8 +312,8 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
316312
}
317313

318314
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
319-
DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
320-
&temp_storage);
315+
DETERMINISTIC, DType>(
316+
i, d, [&](DType x) { return x > pivot; }, u, probs_vec, aggregate, &temp_storage);
321317
if (aggregate > u) {
322318
break;
323319
}
@@ -404,8 +400,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
404400
}
405401

406402
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
407-
DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
408-
&temp_storage);
403+
DETERMINISTIC, DType>(
404+
i, d, [&](DType x) { return x > pivot; }, u, probs_vec, aggregate, &temp_storage);
409405
if (aggregate > u) {
410406
break;
411407
}
@@ -459,8 +455,7 @@ template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
459455
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
460456
typename DType, typename IdType>
461457
__global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, DType* min_p_arr,
462-
IdType* output, bool* success, float min_p_val,
463-
uint32_t d, uint32_t max_min_p_rounds) {
458+
IdType* output, float min_p_val, uint32_t d) {
464459
const uint32_t batch_size = gridDim.x;
465460
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
466461
DType p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx];
@@ -472,9 +467,6 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
472467
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
473468

474469
vec_t<DType, VEC_SIZE> probs_vec;
475-
DType aggregate;
476-
DType q = DType(1);
477-
DType pivot = DType(0);
478470

479471
DType max_p = 0;
480472
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -495,70 +487,50 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
495487
temp_storage.block_aggregate.max_p = max_p;
496488
}
497489
__syncthreads();
498-
DType scaled_p = temp_storage.block_aggregate.max_p * p;
490+
DType pivot = temp_storage.block_aggregate.max_p * p;
499491

500-
IdType sampled_id;
501-
for (uint32_t round = 0; round < max_min_p_rounds; ++round) {
502-
temp_storage.sampled_id = d - 1;
503-
__syncthreads();
504-
DType u = uniform_samples[round * batch_size + bx] * q;
505-
aggregate = DType(0);
506-
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
507-
probs_vec.fill(DType(0));
508-
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
509-
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
510-
}
511-
512-
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
513-
DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
514-
&temp_storage);
515-
if (aggregate > u) {
516-
break;
517-
}
518-
}
519-
__syncthreads();
520-
sampled_id = temp_storage.sampled_id;
521-
pivot = max(pivot, probs[bx * d + sampled_id]);
522-
if (pivot >= scaled_p) {
523-
break;
492+
DType aggregate_gt_pivot = DType(0);
493+
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
494+
probs_vec.fill(DType(0));
495+
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
496+
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
524497
}
525498

526-
DType aggregate_gt_pivot = DType(0);
527-
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
528-
probs_vec.fill(DType(0));
529-
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
530-
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
531-
}
532-
533-
DType probs_gt_pivot[VEC_SIZE];
499+
DType probs_gt_pivot[VEC_SIZE];
534500
#pragma unroll
535-
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
536-
probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
537-
}
501+
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
502+
probs_gt_pivot[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : DType(0);
503+
}
538504

539-
aggregate_gt_pivot += BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
540-
.Sum<VEC_SIZE>(probs_gt_pivot);
541-
if (tx == 0) {
542-
temp_storage.block_aggregate.value = aggregate_gt_pivot;
543-
}
544-
__syncthreads();
505+
aggregate_gt_pivot += BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
506+
.Sum<VEC_SIZE>(probs_gt_pivot);
507+
if (tx == 0) {
508+
temp_storage.block_aggregate.value = aggregate_gt_pivot;
545509
}
546-
q = temp_storage.block_aggregate.value;
510+
__syncthreads();
547511
}
512+
513+
DType aggregate(0);
514+
DType q = temp_storage.block_aggregate.value;
515+
516+
IdType sampled_id;
517+
temp_storage.sampled_id = d - 1;
548518
__syncthreads();
549-
if (tx == 0) {
550-
output[bx] = sampled_id;
551-
if (pivot < scaled_p) {
552-
// failed to sample within MAX_ROUNDS
553-
if (success != nullptr) {
554-
success[bx] = false;
555-
}
556-
} else {
557-
if (success != nullptr) {
558-
success[bx] = true;
559-
}
519+
DType u = uniform_samples[bx] * q;
520+
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
521+
probs_vec.fill(DType(0));
522+
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
523+
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
524+
}
525+
526+
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
527+
DType>(
528+
i, d, [&](DType x) { return x >= pivot; }, u, probs_vec, aggregate, &temp_storage);
529+
if (aggregate > u) {
530+
break;
560531
}
561532
}
533+
output[bx] = temp_storage.sampled_id;
562534
}
563535

564536
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
@@ -596,8 +568,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp
596568
}
597569

598570
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
599-
DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
600-
&temp_storage);
571+
DETERMINISTIC, DType>(
572+
i, d, [&](DType x) { return x > pivot; }, u, probs_vec, aggregate, &temp_storage);
601573
if (aggregate > u) {
602574
break;
603575
}
@@ -749,16 +721,15 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
749721

750722
template <typename T, typename IdType>
751723
cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr, IdType* output,
752-
bool* success, uint32_t batch_size, float min_p_val, uint32_t d,
753-
uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) {
724+
uint32_t batch_size, float min_p_val, uint32_t d,
725+
bool deterministic, cudaStream_t stream = 0) {
754726
constexpr uint32_t BLOCK_THREADS = 1024;
755727
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
756728

757729
const uint32_t smem_size = sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
758730
dim3 nblks(batch_size);
759731
dim3 nthrs(BLOCK_THREADS);
760-
void* args[] = {&probs, &uniform_samples, &min_p_arr, &output,
761-
&success, &min_p_val, &d, &max_rounds};
732+
void* args[] = {&probs, &uniform_samples, &min_p_arr, &output, &min_p_val, &d};
762733

763734
DISPATCH_ALIGNED_VEC_SIZE(
764735
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
@@ -1350,8 +1321,9 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
13501321
}
13511322

13521323
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
1353-
DType>(i, d, DType(0), u, relu_q_minus_p_vec, aggregate_relu_q_minus_p,
1354-
&temp_storage);
1324+
DType>(
1325+
i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p,
1326+
&temp_storage);
13551327
if (aggregate_relu_q_minus_p > u) {
13561328
break;
13571329
}

0 commit comments

Comments
 (0)