Commit fb578e7

feat: improve sampling algorithm robustness (#923)
1. Add an `indices` parameter to the sampling APIs so that multiple samples can be drawn from the same probability distribution.
2. Address the abnormal case of sum(probs) < u, where the earlier implementation could return an impossible (out-of-range) value.
3. Add a unit test that compares the distribution of sampled results (over a large number of trials) against the original distribution.
1 parent 0daed1a commit fb578e7
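
Note on point 1: the `indices` argument is what the diffs below thread through every sampling entry point as `maybe_indices`. When it is supplied, output slot i is sampled from probability row indices[i], so several output slots can reuse the same distribution; when it is absent, slot i keeps using row i as before. A minimal host-side C++ sketch of this mapping (illustrative only, not the CUDA kernel; the helper name `sample_with_indices` is made up):

#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

// Illustrative model of the `indices` semantics: output slot i draws from
// probability row (indices ? indices[i] : i), so several outputs may share
// one distribution. The real implementation is a CUDA kernel.
std::vector<int> sample_with_indices(const std::vector<std::vector<float>>& probs,
                                     const std::vector<int>* indices, int num_outputs,
                                     uint64_t seed) {
  std::mt19937_64 rng(seed);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  std::vector<int> output(num_outputs);
  for (int i = 0; i < num_outputs; ++i) {
    int row = indices ? (*indices)[i] : i;  // several slots may map to the same row
    const std::vector<float>& p = probs[row];
    float u = uniform(rng);
    float cdf = 0.0f;
    int sampled = static_cast<int>(p.size()) - 1;  // fallback if rounding leaves cdf < u
    for (int v = 0; v < static_cast<int>(p.size()); ++v) {
      cdf += p[v];
      if (u < cdf) { sampled = v; break; }
    }
    output[i] = sampled;
  }
  return output;
}

int main() {
  // Two distributions, four output slots: slots {0, 1} share row 0, slots {2, 3} share row 1.
  std::vector<std::vector<float>> probs = {{0.1f, 0.2f, 0.7f}, {0.5f, 0.5f, 0.0f}};
  std::vector<int> indices = {0, 0, 1, 1};
  std::vector<int> out = sample_with_indices(probs, &indices, 4, /*seed=*/42);
  for (int s : out) std::printf("%d ", s);
  std::printf("\n");
  return 0;
}

This mapping is also why the wrappers in csrc/sampling.cu below now take batch_size from output.size(0) rather than probs.size(0): the number of output slots can exceed the number of distinct probability rows.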

File tree

6 files changed: +688 -438 lines

csrc/flashinfer_ops.cu

+24 -15

@@ -155,25 +155,34 @@ void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_r
 
 //========== sampling ==========
 
-void sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
-                         bool deterministic, int64_t cuda_stream);
+void sampling_from_probs(at::Tensor probs, at::Tensor output,
+                         std::optional<at::Tensor> maybe_indices, bool deterministic,
+                         std::optional<at::Generator> gen, int64_t cuda_stream);
 
-void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
+void top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
                                std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
-                               bool deterministic, int64_t cuda_stream);
+                               bool deterministic, std::optional<at::Generator> gen,
+                               int64_t cuda_stream);
 
-void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
+void top_k_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
                                std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
-                               bool deterministic, int64_t cuda_stream);
+                               bool deterministic, std::optional<at::Generator> gen,
+                               int64_t cuda_stream);
 
-void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
+void min_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
                                std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
-                               bool deterministic, int64_t cuda_stream);
+                               bool deterministic, std::optional<at::Generator> gen,
+                               int64_t cuda_stream);
 
-void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples,
-                                     at::Tensor samples, std::optional<at::Tensor> maybe_top_k_arr,
-                                     double top_k_val, std::optional<at::Tensor> maybe_top_p_arr,
-                                     double top_p_val, bool deterministic, int64_t cuda_stream);
+void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                                     std::optional<at::Tensor> maybe_indices,
+                                     std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
+                                     std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
+                                     bool deterministic, std::optional<at::Generator> gen,
+                                     int64_t cuda_stream);
 
 void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
                         std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
@@ -188,10 +197,10 @@ void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
                        int64_t cuda_stream);
 
 void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
-                                at::Tensor uniform_samples, at::Tensor target_probs,
-                                at::Tensor output_token_ids, at::Tensor output_accepted_token_num,
+                                at::Tensor target_probs, at::Tensor output_token_ids,
+                                at::Tensor output_accepted_token_num,
                                 at::Tensor output_emitted_token_num, bool deterministic,
-                                int64_t cuda_stream);
+                                std::optional<at::Generator> gen, int64_t cuda_stream);
 
 //========== Torch Library ==========
 
csrc/flashinfer_sampling_ops.cu

+10 -5

@@ -15,25 +15,30 @@
  */
 #include "pytorch_extension_utils.h"
 
-void sampling_from_probs(at::Tensor probs, at::Tensor samples, bool deterministic,
+void sampling_from_probs(at::Tensor probs, at::Tensor output,
+                         std::optional<at::Tensor> maybe_indices, bool deterministic,
                          std::optional<at::Generator> gen, int64_t cuda_stream);
 
-void top_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
+void top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
                                std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
                                bool deterministic, std::optional<at::Generator> gen,
                                int64_t cuda_stream);
 
-void top_k_sampling_from_probs(at::Tensor probs, at::Tensor samples,
+void top_k_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
                                std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
                                bool deterministic, std::optional<at::Generator> gen,
                                int64_t cuda_stream);
 
-void min_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
+void min_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
                                std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
                                bool deterministic, std::optional<at::Generator> gen,
                                int64_t cuda_stream);
 
-void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
+void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                                     std::optional<at::Tensor> maybe_indices,
                                      std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
                                      std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
                                      bool deterministic, std::optional<at::Generator> gen,
csrc/sampling.cu

+41 -31

@@ -25,12 +25,13 @@
 
 using namespace flashinfer;
 
-void sampling_from_probs(at::Tensor probs, at::Tensor samples, bool deterministic,
+void sampling_from_probs(at::Tensor probs, at::Tensor output,
+                         std::optional<at::Tensor> maybe_indices, bool deterministic,
                          std::optional<at::Generator> gen_, int64_t cuda_stream) {
   CHECK_INPUT(probs);
   auto device = probs.device();
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
-  unsigned int batch_size = probs.size(0);
+  unsigned int batch_size = output.size(0);
   unsigned int vocab_size = probs.size(1);
 
   uint64_t philox_seed, philox_offset;
@@ -43,20 +44,22 @@ void sampling_from_probs(at::Tensor probs, at::Tensor samples, bool deterministi
 
   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   cudaError_t status = sampling::SamplingFromProb(
-      static_cast<float*>(probs.data_ptr()), static_cast<int*>(samples.data_ptr()), batch_size,
-      vocab_size, deterministic, philox_seed, philox_offset, stream);
+      static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
+      maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
+      batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
   TORCH_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
                                          std::string(cudaGetErrorString(status)));
 }
 
-void top_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
+void top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
                                std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
                                bool deterministic, std::optional<at::Generator> gen_,
                                int64_t cuda_stream) {
   CHECK_INPUT(probs);
   auto device = probs.device();
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
-  unsigned int batch_size = probs.size(0);
+  unsigned int batch_size = output.size(0);
   unsigned int vocab_size = probs.size(1);
   bool has_top_p_arr = maybe_top_p_arr.has_value();
   uint64_t philox_seed, philox_offset;
@@ -69,25 +72,26 @@ void top_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
 
   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   cudaError_t status = sampling::TopPSamplingFromProb<float, int>(
-      static_cast<float*>(probs.data_ptr()), static_cast<int*>(samples.data_ptr()),
+      static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
+      maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
       has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
       top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
   TORCH_CHECK(status == cudaSuccess, "TopPSamplingFromProbs failed with error code " +
                                          std::string(cudaGetErrorString(status)));
 }
 
-void top_k_sampling_from_probs(at::Tensor probs, at::Tensor samples,
+void top_k_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
                                std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
                                bool deterministic, std::optional<at::Generator> gen_,
                                int64_t cuda_stream) {
   CHECK_INPUT(probs);
-  CHECK_INPUT(samples);
+  CHECK_INPUT(output);
   auto device = probs.device();
-  CHECK_EQ(samples.device(), device);
-  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
-  CHECK_DIM(1, samples);  // samples: (batch_size)
-  CHECK_EQ(probs.size(0), samples.size(0));
-  unsigned int batch_size = probs.size(0);
+  CHECK_EQ(output.device(), device);
+  CHECK_DIM(2, probs);   // probs: (batch_size, vocab_size)
+  CHECK_DIM(1, output);  // output: (batch_size)
+  unsigned int batch_size = output.size(0);
   unsigned int vocab_size = probs.size(1);
   bool has_top_k_arr = maybe_top_k_arr.has_value();
   uint64_t philox_seed, philox_offset;
@@ -100,24 +104,26 @@ void top_k_sampling_from_probs(at::Tensor probs, at::Tensor samples,
 
   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   cudaError_t status = sampling::TopKSamplingFromProb<float, int>(
-      static_cast<float*>(probs.data_ptr()), static_cast<int*>(samples.data_ptr()),
+      static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
+      maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
       has_top_k_arr ? static_cast<float*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
       top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
   TORCH_CHECK(status == cudaSuccess, "TopKSamplingFromProbs failed with error code " +
                                          std::string(cudaGetErrorString(status)));
 }
 
-void min_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
+void min_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                               std::optional<at::Tensor> maybe_indices,
                                std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
                                bool deterministic, std::optional<at::Generator> gen_,
                                int64_t cuda_stream) {
   CHECK_INPUT(probs);
-  CHECK_INPUT(samples);
+  CHECK_INPUT(output);
   auto device = probs.device();
-  CHECK_EQ(samples.device(), device);
-  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
-  CHECK_DIM(1, samples);  // samples: (batch_size)
-  unsigned int batch_size = probs.size(0);
+  CHECK_EQ(output.device(), device);
+  CHECK_DIM(2, probs);   // probs: (batch_size, vocab_size)
+  CHECK_DIM(1, output);  // output: (batch_size)
+  unsigned int batch_size = output.size(0);
   unsigned int vocab_size = probs.size(1);
   bool has_min_p_arr = maybe_min_p_arr.has_value();
   uint64_t philox_seed, philox_offset;
@@ -132,24 +138,26 @@ void min_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
   cudaError_t status = sampling::MinPSamplingFromProb<float, int>(
       static_cast<float*>(probs.data_ptr()),
       has_min_p_arr ? static_cast<float*>(maybe_min_p_arr->data_ptr()) : nullptr,
-      static_cast<int*>(samples.data_ptr()), batch_size, min_p_val, vocab_size, deterministic,
-      philox_seed, philox_offset, stream);
+      static_cast<int*>(output.data_ptr()),
+      maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
+      batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
   TORCH_CHECK(status == cudaSuccess, "MinPSamplingFromProb failed with error code " +
                                          std::string(cudaGetErrorString(status)));
 }
 
-void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
+void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
+                                     std::optional<at::Tensor> maybe_indices,
                                      std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
                                      std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
                                      bool deterministic, std::optional<at::Generator> gen_,
                                      int64_t cuda_stream) {
   CHECK_INPUT(probs);
-  CHECK_INPUT(samples);
+  CHECK_INPUT(output);
   auto device = probs.device();
-  CHECK_EQ(samples.device(), device);
-  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
-  CHECK_DIM(1, samples);  // samples: (batch_size)
-  unsigned int batch_size = probs.size(0);
+  CHECK_EQ(output.device(), device);
+  CHECK_DIM(2, probs);   // probs: (batch_size, vocab_size)
+  CHECK_DIM(1, output);  // output: (batch_size)
+  unsigned int batch_size = output.size(0);
   unsigned int vocab_size = probs.size(1);
   bool has_top_k_arr = maybe_top_k_arr.has_value();
   bool has_top_p_arr = maybe_top_p_arr.has_value();
@@ -166,8 +174,10 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor samples,
       static_cast<float*>(probs.data_ptr()),
       has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr,
       has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr,
-      static_cast<int*>(samples.data_ptr()), batch_size, top_k_val, top_p_val, vocab_size,
-      deterministic, philox_seed, philox_offset, stream);
+      static_cast<int*>(output.data_ptr()),
+      maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
+      batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
+      stream);
   TORCH_CHECK(status == cudaSuccess, "TopKTopPSamplingFromProbs failed with error code " +
                                          std::string(cudaGetErrorString(status)));
 }

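Point 3 of the commit message adds a test that compares the empirical distribution of sampled results against the input distribution; the test itself is not part of the files shown in this view. The standalone C++ sketch below only illustrates the kind of check described: draw a large number of samples, count token frequencies, and require the maximum deviation from the input probabilities to stay within a loose multiple of the 1/sqrt(trials) statistical noise.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

// Illustrative frequency check: with enough trials, empirical frequencies should track
// the input probabilities; the tolerance is loose so the check is not flaky.
int main() {
  const std::vector<double> probs = {0.05, 0.15, 0.30, 0.50};
  const int trials = 1000000;

  std::mt19937_64 rng(0);
  std::discrete_distribution<int> dist(probs.begin(), probs.end());

  std::vector<int> counts(probs.size(), 0);
  for (int t = 0; t < trials; ++t) counts[dist(rng)]++;

  double max_abs_err = 0.0;
  for (size_t v = 0; v < probs.size(); ++v) {
    double freq = static_cast<double>(counts[v]) / trials;
    max_abs_err = std::max(max_abs_err, std::fabs(freq - probs[v]));
  }

  const double tol = 5.0 / std::sqrt(static_cast<double>(trials));
  std::printf("max |freq - prob| = %.6f (tolerance %.6f)\n", max_abs_err, tol);
  assert(max_abs_err < tol);
  return 0;
}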