
Commit 0dd801d

feat: deterministic sampling (#417)
Our previous sampling kernels relied on cub's BlockScan, which is not deterministic, as reported in NVIDIA/cub#454. This PR implements a deterministic BlockScan using the Blelloch scan algorithm, which is slower than cub's but guarantees determinism.
1 parent 146c31e commit 0dd801d
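
For reference, the work-efficient (Blelloch) scan combines elements in a fixed tree order, which is what makes the result reproducible. Below is a minimal NumPy sketch of the algorithm — illustrative only; the actual kernel in include/flashinfer/sampling.cuh is a CUDA block-level implementation, and blelloch_exclusive_scan is a hypothetical name:

import numpy as np

def blelloch_exclusive_scan(x):
    """Work-efficient (Blelloch) exclusive prefix sum; len(x) must be a power of two."""
    a = np.array(x, dtype=np.float32)
    n = len(a)
    # Up-sweep (reduce): build partial sums bottom-up along a fixed binary tree.
    d = 1
    while d < n:
        for i in range(0, n, 2 * d):
            a[i + 2 * d - 1] += a[i + d - 1]
        d *= 2
    # Down-sweep: convert the tree of partial sums into exclusive prefixes.
    a[n - 1] = 0.0
    d = n // 2
    while d >= 1:
        for i in range(0, n, 2 * d):
            t = a[i + d - 1]
            a[i + d - 1] = a[i + 2 * d - 1]
            a[i + 2 * d - 1] += t
        d //= 2
    return a

print(blelloch_exclusive_scan([0.125] * 8))
# -> [0.    0.125 0.25  0.375 0.5   0.625 0.75  0.875]

Because the addition tree is identical on every run, the float32 prefix sums come out bit-for-bit the same, unlike scans whose combination order depends on hardware scheduling.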

File tree

7 files changed (+283 -115 lines)


Diff for: include/flashinfer/sampling.cuh (+203 -71)

Large diffs are not rendered by default.

Diff for: python/csrc/flashinfer_ops.h (+9 -6)

@@ -54,26 +54,29 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
 
 std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s);
 
-torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples);
+torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples,
+                                  bool deterministic);
 
 std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
-                                                     torch::Tensor uniform_samples, double top_p);
+                                                     torch::Tensor uniform_samples, double top_p,
+                                                     bool deterministic);
 
 std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                      torch::Tensor uniform_samples,
-                                                     unsigned int top_k);
+                                                     unsigned int top_k, bool deterministic);
 
 std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
                                                            torch::Tensor uniform_samples,
-                                                           torch::Tensor top_k,
-                                                           torch::Tensor top_p);
+                                                           torch::Tensor top_k, torch::Tensor top_p,
+                                                           bool deterministic);
 
 torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);
 
 torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);
 
 torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
-                                         torch::Tensor uniform_samples, torch::Tensor target_probs);
+                                         torch::Tensor uniform_samples, torch::Tensor target_probs,
+                                         bool deterministic);
 
 torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);

Diff for: python/csrc/sampling.cu (+17 -14)

@@ -20,7 +20,8 @@
 
 using namespace flashinfer;
 
-torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples) {
+torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples,
+                                  bool deterministic) {
   CHECK_INPUT(probs);
   CHECK_INPUT(uniform_samples);
   auto device = probs.device();
@@ -36,16 +37,18 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
 
-  cudaError_t status = sampling::SamplingFromProb(
-      static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
-      static_cast<int*>(samples.data_ptr()), batch_size, vocab_size, torch_current_stream);
+  cudaError_t status = sampling::SamplingFromProb(static_cast<float*>(probs.data_ptr()),
+                                                  static_cast<float*>(uniform_samples.data_ptr()),
+                                                  static_cast<int*>(samples.data_ptr()), batch_size,
+                                                  vocab_size, deterministic, torch_current_stream);
   TORCH_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
                                          std::string(cudaGetErrorString(status)));
   return samples;
 }
 
 std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
-                                                     torch::Tensor uniform_samples, double top_p) {
+                                                     torch::Tensor uniform_samples, double top_p,
+                                                     bool deterministic) {
   CHECK_INPUT(probs);
   CHECK_INPUT(uniform_samples);
   auto device = probs.device();
@@ -66,7 +69,7 @@ std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
   cudaError_t status = sampling::TopPSamplingFromProb<float, int>(
       static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
       static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), top_p,
-      batch_size, vocab_size, max_top_p_rounds, torch_current_stream);
+      batch_size, vocab_size, max_top_p_rounds, deterministic, torch_current_stream);
   TORCH_CHECK(status == cudaSuccess, "TopPSamplingFromProbs failed with error code " +
                                          std::string(cudaGetErrorString(status)));
 
@@ -75,7 +78,7 @@ std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
 
 std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                      torch::Tensor uniform_samples,
-                                                     unsigned int top_k) {
+                                                     unsigned int top_k, bool deterministic) {
   CHECK_INPUT(probs);
   CHECK_INPUT(uniform_samples);
   auto device = probs.device();
@@ -96,7 +99,7 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
   cudaError_t status = sampling::TopKSamplingFromProb<float, int>(
       static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
       static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), top_k,
-      batch_size, vocab_size, max_top_k_rounds, torch_current_stream);
+      batch_size, vocab_size, max_top_k_rounds, deterministic, torch_current_stream);
   TORCH_CHECK(status == cudaSuccess, "TopKSamplingFromProbs failed with error code " +
                                          std::string(cudaGetErrorString(status)));
 
@@ -105,8 +108,8 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
 
 std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
                                                            torch::Tensor uniform_samples,
-                                                           torch::Tensor top_k,
-                                                           torch::Tensor top_p) {
+                                                           torch::Tensor top_k, torch::Tensor top_p,
+                                                           bool deterministic) {
   CHECK_INPUT(probs);
   CHECK_INPUT(uniform_samples);
   CHECK_INPUT(top_k);
@@ -138,7 +141,7 @@ std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
       static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
       static_cast<int*>(top_k.data_ptr()), static_cast<float*>(top_p.data_ptr()),
       static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), batch_size,
-      vocab_size, max_rounds, torch_current_stream);
+      vocab_size, max_rounds, deterministic, torch_current_stream);
   TORCH_CHECK(status == cudaSuccess, "TopKTopPSamplingFromProbs failed with error code " +
                                          std::string(cudaGetErrorString(status)));
 
@@ -187,8 +190,8 @@ torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double
 }
 
 torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
-                                         torch::Tensor uniform_samples,
-                                         torch::Tensor target_probs) {
+                                         torch::Tensor uniform_samples, torch::Tensor target_probs,
+                                         bool deterministic) {
   CHECK_INPUT(draft_probs);
   CHECK_INPUT(draft_token_ids);
   CHECK_INPUT(uniform_samples);
@@ -224,7 +227,7 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
       static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
       static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
      static_cast<int*>(output_token_ids.data_ptr()), batch_size, num_speculate_tokens, vocab_size,
-      torch_current_stream);
+      deterministic, torch_current_stream);
 
   TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
                                          std::string(cudaGetErrorString(status)));

Diff for: python/flashinfer/sampling.py (+30 -8)

@@ -32,7 +32,7 @@
 
 
 def sampling_from_probs(
-    probs: torch.Tensor, uniform_samples: torch.Tensor
+    probs: torch.Tensor, uniform_samples: torch.Tensor, deterministic: bool = True
 ) -> torch.Tensor:
     r"""Fused GPU kernel for category sampling from probabilities.
 
@@ -43,6 +43,8 @@ def sampling_from_probs(
     uniform_samples: torch.Tensor
         The uniform samples used as needle for sampling, shape ``(batch_size,)``.
         Expected to be uniformly distributed in ``[0, 1)``.
+    deterministic: bool
+        Whether to use deterministic kernel implementation, default is ``True``.
 
     Returns
     -------
@@ -73,11 +75,14 @@ def sampling_from_probs(
     -----
     This function expects float32 inputs, and the output is int32.
     """
-    return _kernels.sampling_from_probs(probs, uniform_samples)
+    return _kernels.sampling_from_probs(probs, uniform_samples, deterministic)
 
 
 def top_p_sampling_from_probs(
-    probs: torch.Tensor, uniform_samples: torch.Tensor, top_p: float
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_p: float,
+    deterministic: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
     this operator implements GPU-based rejection sampling without explicit sorting.
@@ -95,6 +100,8 @@ def top_p_sampling_from_probs(
         Expected to be uniformly distributed in ``[0, 1)``.
     top_p: float
         The threshold for top-p sampling.
+    deterministic: bool
+        Whether to use deterministic kernel implementation, default is ``True``.
 
     Returns
     -------
@@ -134,11 +141,16 @@ def top_p_sampling_from_probs(
     We encourage users to set ``max_top_p_rounds`` to a reasonable value, e.g., 32. The actual
     implementation usually use much fewer rounds for rejection sampling because of early stopping.
     """
-    return _kernels.top_p_sampling_from_probs(probs, uniform_samples, top_p)
+    return _kernels.top_p_sampling_from_probs(
+        probs, uniform_samples, top_p, deterministic
+    )
 
 
 def top_k_sampling_from_probs(
-    probs: torch.Tensor, uniform_samples: torch.Tensor, top_k: int
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_k: int,
+    deterministic: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Fused GPU kernel for top-k sampling from probabilities,
     this operator implements GPU-based rejection sampling without explicit sorting.
@@ -156,6 +168,8 @@ def top_k_sampling_from_probs(
         Expected to be uniformly distributed in ``[0, 1)``.
     top_k: int
         The k in "top-k".
+    deterministic: bool
+        Whether to use deterministic kernel implementation, default is ``True``.
 
     Returns
     -------
@@ -195,14 +209,17 @@ def top_k_sampling_from_probs(
     We encourage users to set ``max_top_k_rounds`` to a reasonable value, e.g., 32. The actual
     implementation usually use much fewer rounds for rejection sampling because of early stopping.
     """
-    return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k)
+    return _kernels.top_k_sampling_from_probs(
+        probs, uniform_samples, top_k, deterministic
+    )
 
 
 def top_k_top_p_sampling_from_probs(
     probs: torch.Tensor,
     uniform_samples: torch.Tensor,
     top_k: torch.Tensor,
     top_p: torch.Tensor,
+    deterministic: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Fused GPU kernel for joint top-k and top-p sampling from probabilities,
@@ -223,6 +240,8 @@ def top_k_top_p_sampling_from_probs(
         The k in "top-k" for each request, shape ``(batch_size,)``.
     top_p: torch.Tensor
         The threshold for top-p sampling for each request, shape ``(batch_size,)``.
+    deterministic: bool
+        Whether to use deterministic kernel implementation, default is ``True``.
 
     Returns
     -------
@@ -264,7 +283,7 @@ def top_k_top_p_sampling_from_probs(
     implementation usually use much fewer rounds for rejection sampling because of early stopping.
     """
     return _kernels.top_k_top_p_sampling_from_probs(
-        probs, uniform_samples, top_k, top_p
+        probs, uniform_samples, top_k, top_p, deterministic
     )
 
 
@@ -328,6 +347,7 @@ def chain_speculative_sampling(
     draft_token_ids,
     uniform_samples,
     target_probs,
+    deterministic: bool = True,
 ) -> torch.Tensor:
     r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in
     paper `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/pdf/2302.01318>`_),
@@ -349,6 +369,8 @@ def chain_speculative_sampling(
         Compared to input :attr:`draft_probs`, the target model's probability has an additional
         slot at the end because the target model will generate one more token than the draft model.
         Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)``
+    deterministic: bool
+        Whether to use deterministic kernel implementation, default is ``True``.
 
     Returns
     -------
@@ -361,5 +383,5 @@ def chain_speculative_sampling(
         Shape: (batch_size, num_specutate_tokens + 1)
     """
     return _kernels.chain_speculative_sampling(
-        draft_probs, draft_token_ids, uniform_samples, target_probs
+        draft_probs, draft_token_ids, uniform_samples, target_probs, deterministic
     )
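
A usage sketch for the new flag on the Python API (assumes a CUDA device and a flashinfer build that includes this commit; shapes and dtypes follow the docstrings above, values are arbitrary):

import torch
import flashinfer

torch.manual_seed(42)
batch_size, vocab_size = 4, 32000

# Row-normalized probabilities, shape (batch_size, vocab_size), float32.
pre = torch.rand(batch_size, vocab_size, device="cuda")
probs = pre / pre.sum(dim=-1, keepdim=True)
# Uniform needles in [0, 1), shape (batch_size,).
uniform_samples = torch.rand(batch_size, device="cuda")

# deterministic=True (the new default) selects the Blelloch-scan kernel, so
# repeated calls with identical inputs return identical token ids.
samples = flashinfer.sampling.sampling_from_probs(
    probs, uniform_samples, deterministic=True
)
print(samples.shape, samples.dtype)  # torch.Size([4]) torch.int32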

Diff for: src/bench_sampling.cu (+12 -6)

@@ -26,6 +26,7 @@ template <typename T>
 void bench_sampling_with_probability(nvbench::state& state) {
   size_t batch_size = state.get_int64("batch_size");
   size_t vocab_size = state.get_int64("vocab_size");
+  bool deterministic = state.get_int64("determinisic");
 
   std::vector<T> probs_h(batch_size * vocab_size);
   std::vector<T> uniform_samples_h(batch_size);
@@ -55,7 +56,7 @@ void bench_sampling_with_probability(nvbench::state& state) {
     cudaError_t status = sampling::SamplingFromProb<T>(
         thrust::raw_pointer_cast(probs_d.data()),
         thrust::raw_pointer_cast(uniform_samples_d.data()),
-        thrust::raw_pointer_cast(output_d.data()), batch_size, vocab_size);
+        thrust::raw_pointer_cast(output_d.data()), batch_size, vocab_size, deterministic);
     timer.stop();
     if (status != cudaSuccess) {
       state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
@@ -67,6 +68,7 @@ template <typename T>
 void bench_top_p_sampling_with_probability(nvbench::state& state) {
   size_t batch_size = state.get_int64("batch_size");
   size_t vocab_size = state.get_int64("vocab_size");
+  bool deterministic = state.get_int64("determinisic");
   double p = state.get_float64("p");
   constexpr uint32_t max_top_p_rounds = 32;
 
@@ -100,7 +102,7 @@ void bench_top_p_sampling_with_probability(nvbench::state& state) {
         thrust::raw_pointer_cast(probs_d.data()),
         thrust::raw_pointer_cast(uniform_samples_d.data()),
         thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), p,
-        batch_size, vocab_size, max_top_p_rounds);
+        batch_size, vocab_size, max_top_p_rounds, deterministic);
     timer.stop();
     if (status != cudaSuccess) {
       state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
@@ -113,6 +115,7 @@ void bench_top_k_sampling_with_probability(nvbench::state& state) {
   size_t batch_size = state.get_int64("batch_size");
   size_t vocab_size = state.get_int64("vocab_size");
   size_t k = state.get_int64("k");
+  bool deterministic = state.get_int64("determinisic");
   constexpr uint32_t max_top_k_rounds = 32;
 
   std::vector<T> probs_h(batch_size * vocab_size);
@@ -145,7 +148,7 @@ void bench_top_k_sampling_with_probability(nvbench::state& state) {
         thrust::raw_pointer_cast(probs_d.data()),
         thrust::raw_pointer_cast(uniform_samples_d.data()),
         thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), k,
-        batch_size, vocab_size, max_top_k_rounds);
+        batch_size, vocab_size, max_top_k_rounds, deterministic);
     timer.stop();
     if (status != cudaSuccess) {
       state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
@@ -157,18 +160,21 @@ auto bench_sampling_with_probability_f32 = bench_sampling_with_probability<float
 NVBENCH_BENCH(bench_sampling_with_probability_f32)
     .set_name("bench_sampling_with_probability_f32")
     .add_int64_axis("batch_size", {16, 32, 128, 512, 2048})
-    .add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000});
+    .add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000})
+    .add_int64_axis("determinisic", {0, 1});
 
 auto bench_top_p_sampling_with_probability_f32 = bench_top_p_sampling_with_probability<float>;
 NVBENCH_BENCH(bench_top_p_sampling_with_probability_f32)
     .set_name("bench_top_p_sampling_with_probability_f32")
     .add_int64_axis("batch_size", {16, 32, 128, 512, 2048})
     .add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000})
-    .add_float64_axis("p", {0.1, 0.5, 0.9, 1.0});
+    .add_float64_axis("p", {0.1, 0.5, 0.9, 1.0})
+    .add_int64_axis("determinisic", {0, 1});
 
 auto bench_top_k_sampling_with_probability_f32 = bench_top_k_sampling_with_probability<float>;
 NVBENCH_BENCH(bench_top_k_sampling_with_probability_f32)
     .set_name("bench_top_k_sampling_with_probability_f32")
     .add_int64_axis("batch_size", {16, 32, 128, 512, 2048})
     .add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000})
-    .add_int64_axis("k", {16, 32, 128, 1024});
+    .add_int64_axis("k", {16, 32, 128, 1024})
+    .add_int64_axis("determinisic", {0, 1});
