Commit 6e028eb

feat: Fused GPU sampling kernel for joint top-k & top-p sampling (#374)
Currently our sampling kernels support either top-k or top-p sampling, but not both at once. Since these two filters are often applied together in practice, this PR implements a fused kernel that performs top-k and top-p sampling jointly.
1 parent e14fa81 commit 6e028eb
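For context (an editorial addition, not part of the commit message): "jointly" here means the sampled token must satisfy both filters at once, i.e. it must be among the k highest-probability tokens and inside the nucleus whose mass reaches p. A sort-based PyTorch reference of that semantics, which the fused kernel avoids computing explicitly, might look like the sketch below; the helper name and the exact threshold/tie-breaking conventions are assumptions for illustration, not FlashInfer API.

```python
# Editorial reference sketch of joint top-k + top-p sampling via explicit sorting;
# naive_top_k_top_p_sample is a hypothetical helper, not part of FlashInfer.
import torch

def naive_top_k_top_p_sample(probs, top_k, top_p, generator=None):
    # probs: (batch_size, vocab_size), rows already sum to 1
    batch_size, vocab_size = probs.shape
    sorted_probs, order = torch.sort(probs, descending=True, dim=-1)
    cdf = torch.cumsum(sorted_probs, dim=-1)
    ranks = torch.arange(vocab_size, device=probs.device).expand(batch_size, -1)
    # keep a sorted position if it is within the top-k and still needed to reach mass p
    keep_sorted = (ranks < top_k.unsqueeze(-1)) & ((cdf - sorted_probs) < top_p.unsqueeze(-1))
    keep = torch.zeros_like(probs, dtype=torch.bool).scatter(-1, order, keep_sorted)
    masked = torch.where(keep, probs, torch.zeros_like(probs))
    masked = masked / masked.sum(dim=-1, keepdim=True)  # renormalize over the joint mask
    return torch.multinomial(masked, num_samples=1, generator=generator).squeeze(-1)
```

The fused kernel in this PR targets the same semantics (up to eps and tie-breaking) without sorting or materializing the mask, using rejection sampling inside a single kernel launch; see include/flashinfer/sampling.cuh below.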

File tree

8 files changed, +325 -16 lines changed


include/flashinfer/attention/cascade.cuh

+4 -6
@@ -89,12 +89,11 @@ __global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__
 template <uint32_t vec_size, typename DType>
 __global__ void MergeStateInPlaceKernel(DType* __restrict__ v, float* __restrict__ s,
                                         DType* __restrict__ v_other, float* __restrict__ s_other,
-                                        uint8_t* __restrict__ mask,
-                                        uint32_t num_heads, uint32_t head_dim) {
+                                        uint8_t* __restrict__ mask, uint32_t num_heads,
+                                        uint32_t head_dim) {
   uint32_t pos = blockIdx.x;

-  if (mask != nullptr && mask[pos] == 0)
-    return;
+  if (mask != nullptr && mask[pos] == 0) return;

   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t head_idx = ty;
@@ -396,8 +395,7 @@ cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DType
  */
 template <typename DType>
 cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other, uint32_t seq_len,
-                              uint32_t num_heads, uint32_t head_dim,
-                              uint8_t* mask = nullptr,
+                              uint32_t num_heads, uint32_t head_dim, uint8_t* mask = nullptr,
                               cudaStream_t stream = nullptr) {
   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
     constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U);

include/flashinfer/sampling.cuh

+114
@@ -16,6 +16,8 @@
 #ifndef FLASHINFER_SAMPLING_CUH_
 #define FLASHINFER_SAMPLING_CUH_

+#include <driver_types.h>
+
 #include <cub/block/block_adjacent_difference.cuh>
 #include <cub/block/block_reduce.cuh>
 #include <cub/block/block_scan.cuh>
@@ -342,6 +344,96 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
   }
 }

+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, typename DType, typename IdType>
+__global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* top_k,
+                                               DType* top_p, IdType* output, bool* success,
+                                               uint32_t d, uint32_t max_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  IdType k = top_k[bx];
+  DType p = top_p[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+      uint8_t smem_sampling[];
+  auto& temp_storage = reinterpret_cast<
+      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(0);
+  DType pivot = DType(0);
+  IdType sampled_id;
+  for (uint32_t round = 0; round < max_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DType>(
+          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = probs[bx * d + sampled_id];
+
+    Pair<DType> aggregate_leq_pivot{DType(0), 0};
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      Pair<DType> probs_leq_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_leq_pivot[j] = {
+            (probs_vec[j] <= pivot) ? probs_vec[j] : DType(0),
+            (probs_vec[j] <= pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+      }
+
+      aggregate_leq_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                 temp_storage.block_prim.reduce_pair)
+                                 .Sum<VEC_SIZE>(probs_leq_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.pair = aggregate_leq_pivot;
+      }
+      __syncthreads();
+      if (temp_storage.data.block_aggregate.pair.count + k > d &&
+          float(temp_storage.data.block_aggregate.pair.value) + p > 1 + eps) {
+        break;
+      }
+    }
+    q = temp_storage.data.block_aggregate.pair.value;
+    if (temp_storage.data.block_aggregate.pair.count + k > d && float(q) + p > 1 + eps) {
+      break;
+    }
+  }
+  __syncthreads();
+  if (tx == 0) {
+    if (temp_storage.data.block_aggregate.pair.count + k <= d || float(q) + p <= 1 + eps) {
+      // failed to sample within MAX_TOP_P_ROUNDS
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      output[bx] = sampled_id;
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
 template <typename T, typename IdType>
 cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size,
                              uint32_t d, cudaStream_t stream = 0) {
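Editorial note on the kernel above: each round draws a needle, samples a token among the still-admissible candidates, raises the pivot to that token's probability, and accepts as soon as fewer than k tokens lie strictly above the pivot and the mass strictly above the pivot is below p. A pure-PyTorch sketch of that per-request loop is given below; it is an illustration only, and in particular assumes that DeviceSamplingFromProb restricts the inverse-CDF search to tokens with probability strictly greater than the current pivot.

```python
# Editorial sketch only: mirrors TopKTopPSamplingFromProbKernel's per-request loop
# on the host, under the assumption stated above; not the actual implementation.
import torch

def topk_topp_rejection_sample_reference(probs, uniform, k, p, eps=1e-5):
    # probs: (d,) normalized probabilities of one request; uniform: (max_rounds,) in [0, 1)
    d = probs.numel()
    pivot, q = 0.0, 0.0              # q = probability mass already ruled out (probs <= pivot)
    sampled_id, accepted = d - 1, False
    for r in range(uniform.numel()):
        u = float(uniform[r]) * (1.0 - q)
        # inverse-CDF draw restricted to the remaining candidates (prob > pivot)
        masked = torch.where(probs > pivot, probs, torch.zeros_like(probs))
        cdf = torch.cumsum(masked, dim=0)
        idx = torch.searchsorted(cdf, torch.tensor(u, dtype=probs.dtype), right=True)
        sampled_id = int(idx.clamp(max=d - 1))
        pivot = float(probs[sampled_id])
        leq = probs <= pivot
        count = int(leq.sum())       # tokens with prob <= pivot
        q = float(probs[leq].sum())  # mass with prob <= pivot
        # accept iff fewer than k tokens sit strictly above the pivot (top-k holds)
        # and the mass strictly above the pivot, 1 - q, is below p (top-p holds)
        if count + k > d and q + p > 1 + eps:
            accepted = True
            break
    return sampled_id, accepted
```

On rejection the next round's needle is scaled by the remaining mass 1 - q, which is why the kernel multiplies `uniform_samples[round * batch_size + bx]` by `(1 - q)`.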
@@ -434,6 +526,28 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
   return cudaSuccess;
 }

+template <typename T, typename IdType>
+cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k, T* top_p,
+                                     IdType* output, bool* success, uint32_t batch_size, uint32_t d,
+                                     uint32_t max_rounds, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size = sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &uniform_samples, &top_k, &top_p, &output, &success, &d, &max_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel =
+        TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, T, IdType>;
+    FLASHINFER_CUDA_CALL(
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
 template <typename T, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
 struct RenormTempStorage {
   union {
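The launcher above uses one thread block of 1024 threads per request and picks a vector width with `std::gcd(16 / sizeof(T), d)`, i.e. the widest lane count (up to 16 bytes per load) that still divides the vocabulary size. A small editorial illustration of that expression, assuming `T = float`, for the vocabulary sizes exercised by the unit tests:

```python
# Editorial illustration of vec_size = std::gcd(16 / sizeof(T), d); not part of the PR.
import math

for d in (111, 500, 32000, 128256):   # vocab sizes used in python/tests/test_sampling.py
    vec_size = math.gcd(16 // 4, d)   # 16 bytes / sizeof(float) = 4 lanes max
    print(d, "->", vec_size)
# 111 -> 1, 500 -> 4, 32000 -> 4, 128256 -> 4
```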

python/csrc/flashinfer_ops.cu

+2
@@ -35,6 +35,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Top-k sampling from probabilities");
   m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
         "Top-p sampling from probabilities");
+  m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs,
+        "Top-k and top-p sampling from probabilities");
   m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probabilities by top-k mask");
   m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probabilities by top-p mask");
   m.def("chain_speculative_sampling", &chain_speculative_sampling,

python/csrc/flashinfer_ops.h

+5
@@ -59,6 +59,11 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                      torch::Tensor uniform_samples,
                                                      unsigned int top_k);

+std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
+                                                            torch::Tensor uniform_samples,
+                                                            torch::Tensor top_k,
+                                                            torch::Tensor top_p);
+
 torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);

 torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);

python/csrc/sampling.cu

+42
@@ -103,6 +103,48 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
   return {samples, success};
 }

+std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
+                                                            torch::Tensor uniform_samples,
+                                                            torch::Tensor top_k,
+                                                            torch::Tensor top_p) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  CHECK_INPUT(top_k);
+  CHECK_INPUT(top_p);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_EQ(top_k.device(), device);
+  CHECK_EQ(top_p.device(), device);
+  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
+  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_rounds, batch_size)
+  CHECK_DIM(1, top_k);            // top_k: (batch_size,)
+  CHECK_DIM(1, top_p);            // top_p: (batch_size,)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_rounds = uniform_samples.size(0);
+  CHECK_EQ(uniform_samples.size(1), batch_size);
+  CHECK_EQ(top_k.size(0), batch_size);
+  CHECK_EQ(top_p.size(0), batch_size);
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+  top_k = top_k.to(torch::kInt32);
+  top_p = top_p.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::TopKTopPSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
+      static_cast<int*>(top_k.data_ptr()), static_cast<float*>(top_p.data_ptr()),
+      static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), batch_size,
+      vocab_size, max_rounds, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess, "TopKTopPSamplingFromProbs failed with error code " +
+                                         std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
 torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps) {
   CHECK_INPUT(probs);
   auto device = probs.device();

python/flashinfer/sampling.py

+71
@@ -196,6 +196,77 @@ def top_k_sampling_from_probs(
     return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k)


+def top_k_top_p_sampling_from_probs(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_k: torch.Tensor,
+    top_p: torch.Tensor,
+):
+    r"""Fused GPU kernel for joint top-k and top-p sampling from probabilities.
+
+    This operator implements GPU-based rejection sampling without explicit sorting.
+
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than a naive implementation that launches a series of kernels.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities, shape ``(batch_size, num_classes)``.
+    uniform_samples: torch.Tensor
+        The uniform samples used as needles for sampling, shape ``(max_top_k_rounds, batch_size,)``,
+        where the first dimension is the maximum number of rounds for rejection sampling.
+        Expected to be uniformly distributed in ``[0, 1)``.
+    top_k: torch.Tensor
+        The k in "top-k" for each request, shape ``(batch_size,)``.
+    top_p: torch.Tensor
+        The threshold for top-p sampling for each request, shape ``(batch_size,)``.
+
+    Returns
+    -------
+    (samples, success): Tuple[torch.Tensor, torch.Tensor]
+        samples: torch.Tensor
+            Sampled categories, shape ``(batch_size,)``.
+        success: torch.Tensor
+            Whether the sampling succeeded within ``max_top_k_rounds`` rounds,
+            shape ``(batch_size,)``.
+
+    Examples
+    --------
+
+    >>> import torch
+    >>> import flashinfer
+    >>> torch.manual_seed(42)
+    >>> batch_size = 4
+    >>> vocab_size = 5
+    >>> max_rounds = 3
+    >>> top_p = torch.full((batch_size,), 0.2).to(0)
+    >>> top_k = torch.full((batch_size,), 2).to(0)
+    >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
+    >>> norm_prob
+    tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
+            [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
+            [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
+            [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
+    >>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
+    >>> samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(norm_prob, uniform_samples, top_k, top_p)
+    >>> samples
+    tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
+    >>> success
+    tensor([True, True, True, True], device='cuda:0')
+
+    Notes
+    -----
+    This function expects float32 inputs, and the output is int32.
+    We encourage users to set ``max_rounds`` to a reasonable value, e.g., 32. The actual
+    implementation usually uses much fewer rounds for rejection sampling because of early stopping.
+    """
+    return _kernels.top_k_top_p_sampling_from_probs(
+        probs, uniform_samples, top_k, top_p
+    )
+
+
 def top_p_renorm_prob(probs: torch.Tensor, top_p: float, eps: float = 1e-5):
     r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding.
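One practical note on the Python API above (editorial; this retry policy is not part of the PR): since the kernel gives up after ``max_top_k_rounds`` rounds and reports per-request ``success`` flags, a caller can redraw the uniform needles and invoke the operator again for the rare failures, as in the hedged sketch below (``sample_with_retry`` is a hypothetical helper).

```python
# Hedged usage sketch, not FlashInfer API: redraw the needles and retry when
# rejection sampling does not converge within max_rounds.
import torch
import flashinfer

def sample_with_retry(probs, top_k, top_p, max_rounds=32, max_retries=3):
    batch_size = probs.shape[0]
    for _ in range(max_retries):
        uniform_samples = torch.rand(max_rounds, batch_size, device=probs.device)
        samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(
            probs, uniform_samples, top_k, top_p
        )
        if bool(success.all()):
            break
    # resamples the whole batch for simplicity; one could also mask to the failed rows
    return samples, success
```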

python/tests/test_sampling.py

+44
@@ -95,6 +95,50 @@ def test_top_k_sampling(batch_size, vocab_size, k):
     ]


+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("p", [0.1, 0.5])
+def test_top_k_top_p_sampling(batch_size, vocab_size, p):
+    if p == 0.1:
+        k = int(vocab_size * 0.5)
+    elif p == 0.5:
+        k = int(vocab_size * 0.1)
+    else:
+        raise ValueError("p not recognized")
+    max_top_k_trails = 32
+    eps = 1e-4
+    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
+    # top-p mask
+    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
+    cdf = torch.cumsum(sorted_prob, dim=-1)
+    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
+    mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
+    # top-k mask
+    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
+    pivot = sorted_prob[:, k - 1]
+    mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
+    # overall mask
+    mask = torch.minimum(mask_top_p, mask_top_k)
+    uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
+        0
+    )
+    top_p_tensor = torch.full((batch_size,), p).to(0)
+    top_k_tensor = torch.full((batch_size,), k).to(0)
+
+    num_trails = 1000
+    for _ in range(num_trails):
+        uniform_samples.uniform_()
+        samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(
+            normalized_prob, uniform_samples, top_k_tensor, top_p_tensor
+        )
+        assert torch.all(success)
+        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
+        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
+            torch.arange(batch_size), samples
+        ]
+
+
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
 @pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
 @pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
