
Commit cea2bb9

sampling: fused speculative sampling kernels (#259)
- [x] Fused chain verify sampling
- [ ] ~~Fused tree verify sampling~~ (left for next PR)
- [x] Renorm top-p, top-k
1 parent 7e9cc7f commit cea2bb9

10 files changed: +680 −21 lines

Diff for: include/flashinfer/sampling.cuh

+368-3
Large diffs are not rendered by default.
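The kernels themselves are added in include/flashinfer/sampling.cuh, which is not rendered above. As a rough sketch only, the chain-verification rule that the new ChainSpeculativeSampling kernel implements (per the speculative sampling paper cited in python/flashinfer/sampling.py below) can be written as a slow PyTorch reference. The function name here is made up, and the use of torch.multinomial for the residual and bonus draws is an illustrative simplification; the actual kernel consumes the caller-provided uniform_samples instead.

import torch

def chain_speculative_sampling_reference(draft_probs, draft_token_ids,
                                          uniform_samples, target_probs):
    # Slow per-request reference of the chain accept/reject rule.
    # draft_probs:     (batch, n, vocab)     draft-model distributions
    # draft_token_ids: (batch, n)            tokens proposed by the draft model
    # uniform_samples: (batch, n + 1)        U[0, 1) samples for accept/reject
    # target_probs:    (batch, n + 1, vocab) target-model distributions
    batch, n, vocab = draft_probs.shape
    out = torch.full((batch, n + 1), -1, dtype=torch.int32)
    for b in range(batch):
        all_accepted = True
        for i in range(n):
            tok = int(draft_token_ids[b, i])
            p_target = target_probs[b, i, tok]
            p_draft = draft_probs[b, i, tok]
            # Accept the draft token with probability min(1, p_target / p_draft).
            if uniform_samples[b, i] < torch.clamp(p_target / p_draft, max=1.0):
                out[b, i] = tok
            else:
                # On rejection, resample from the normalized residual
                # max(0, target - draft); the remaining positions stay -1.
                residual = torch.clamp(target_probs[b, i] - draft_probs[b, i], min=0.0)
                out[b, i] = int(torch.multinomial(residual / residual.sum(), 1))
                all_accepted = False
                break
        if all_accepted:
            # Bonus token from the extra slot of the target distribution.
            out[b, n] = int(torch.multinomial(target_probs[b, n], 1))
    return out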

Diff for: python/csrc/batch_prefill.cu

+4-4
@@ -38,8 +38,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
 
   cudaError_t status =
       handler_->BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
-                             workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
-                             batch_size, num_qo_heads, num_kv_heads, head_dim);
+                             workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
+                             batch_size, num_qo_heads, num_kv_heads, head_dim);
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
               cudaGetErrorString(status));
 }

@@ -166,8 +166,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
 
   cudaError_t status =
       handler_->BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
-                             workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
-                             batch_size, num_qo_heads, num_kv_heads, head_dim);
+                             workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
+                             batch_size, num_qo_heads, num_kv_heads, head_dim);
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
               cudaGetErrorString(status));
 }

Diff for: python/csrc/flashinfer_ops.cu

+4
@@ -34,6 +34,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Top-k sampling from probabilities");
   m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
         "Top-p sampling from probabilities");
+  m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probabilities by top-k mask");
+  m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probabilities by top-p mask");
+  m.def("chain_speculative_sampling", &chain_speculative_sampling,
+        "Speculative sampling from sequence of probabilities");
   m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
   py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
                                                         "BatchDecodeWithPagedKVCachePyTorchWrapper")

Diff for: python/csrc/flashinfer_ops.h

+8-2
@@ -62,6 +62,13 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                      torch::Tensor uniform_samples,
                                                      unsigned int top_k);
 
+torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);
+
+torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);
+
+torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
+                                         torch::Tensor uniform_samples, torch::Tensor target_probs);
+
 torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);
 
 class BatchDecodeWithPagedKVCachePyTorchWrapper {

@@ -83,8 +90,7 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
   BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
                                             unsigned int max_workspace_size_in_bytes)
       : kv_layout_(flashinfer::QKVLayout(layout)),
-        handler_(
-            std::make_shared<flashinfer::BatchDecodeHandler>(max_workspace_size_in_bytes)) {}
+        handler_(std::make_shared<flashinfer::BatchDecodeHandler>(max_workspace_size_in_bytes)) {}
 
  protected:
   std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;

Diff for: python/csrc/sampling.cu

+82
@@ -96,3 +96,85 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
 
   return {samples, success};
 }
+
+torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps) {
+  CHECK_INPUT(probs);
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  probs = probs.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto renorm_probs =
+      torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(probs.device()));
+
+  cudaError_t status = sampling::TopPRenormProb<float>(
+      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()), top_p,
+      eps, batch_size, vocab_size, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
+  return renorm_probs;
+}
+
+torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps) {
+  CHECK_INPUT(probs);
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  probs = probs.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto renorm_probs =
+      torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(probs.device()));
+
+  cudaError_t status = sampling::TopKRenormProb<float>(
+      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()), top_k,
+      eps, batch_size, vocab_size, torch_current_stream);
+
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
+  return renorm_probs;
+}
+
+torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
+                                         torch::Tensor uniform_samples,
+                                         torch::Tensor target_probs) {
+  CHECK_INPUT(draft_probs);
+  CHECK_INPUT(draft_token_ids);
+  CHECK_INPUT(uniform_samples);
+  CHECK_INPUT(target_probs);
+  CHECK_DIM(3, draft_probs);      // draft_probs: (batch_size, num_speculate_tokens, vocab_size)
+  CHECK_DIM(2, draft_token_ids);  // draft_token_ids: (batch_size, num_speculate_tokens)
+  CHECK_DIM(2, uniform_samples);  // uniform_samples: (batch_size, num_speculate_tokens + 1)
+  CHECK_DIM(3, target_probs);     // target_probs: (batch_size, num_speculate_tokens + 1, vocab_size)
+  unsigned int batch_size = draft_probs.size(0);
+  unsigned int num_speculate_tokens = draft_probs.size(1);
+  unsigned int vocab_size = draft_probs.size(2);
+  CHECK_EQ(batch_size, draft_token_ids.size(0));
+  CHECK_EQ(batch_size, uniform_samples.size(0));
+  CHECK_EQ(batch_size, target_probs.size(0));
+  CHECK_EQ(num_speculate_tokens + 1, uniform_samples.size(1));
+  CHECK_EQ(num_speculate_tokens + 1, target_probs.size(1));
+  CHECK_EQ(vocab_size, target_probs.size(2));
+
+  draft_probs = draft_probs.to(torch::kFloat32);
+  draft_token_ids = draft_token_ids.to(torch::kInt32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+  target_probs = target_probs.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto output_token_ids =
+      torch::empty({batch_size, num_speculate_tokens + 1},
+                   torch::dtype(torch::kInt32).device(draft_token_ids.device()));
+
+  cudaError_t status = sampling::ChainSpeculativeSampling<float, int>(
+      static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
+      static_cast<int*>(output_token_ids.data_ptr()), batch_size, num_speculate_tokens, vocab_size,
+      torch_current_stream);
+
+  TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
+                                         std::string(cudaGetErrorString(status)));
+
+  return output_token_ids;
+}
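For orientation, the shape contract enforced by the CHECK_DIM/CHECK_EQ calls above can be exercised from Python roughly as follows. This is a hedged sketch: the sizes are arbitrary, and it assumes a CUDA build of this commit with the new function exported through flashinfer's __init__.py (shown in the next diff).

import torch
import flashinfer

batch_size, num_speculate_tokens, vocab_size = 2, 3, 8
device = "cuda"

# Draft-model distributions and the tokens it proposed.
draft_probs = torch.softmax(
    torch.randn(batch_size, num_speculate_tokens, vocab_size, device=device), dim=-1)
draft_token_ids = torch.randint(
    vocab_size, (batch_size, num_speculate_tokens), dtype=torch.int32, device=device)
# One extra slot for the "bonus" token drawn from the target model.
uniform_samples = torch.rand(batch_size, num_speculate_tokens + 1, device=device)
target_probs = torch.softmax(
    torch.randn(batch_size, num_speculate_tokens + 1, vocab_size, device=device), dim=-1)

output_token_ids = flashinfer.chain_speculative_sampling(
    draft_probs, draft_token_ids, uniform_samples, target_probs
)
print(output_token_ids.shape)  # (batch_size, num_speculate_tokens + 1); rejected tail padded with -1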

Diff for: python/flashinfer/__init__.py

+3
@@ -40,6 +40,9 @@
     sampling_from_probs,
     top_p_sampling_from_probs,
     top_k_sampling_from_probs,
+    top_p_renorm_prob,
+    top_k_renorm_prob,
+    chain_speculative_sampling,
 )
 from .norm import rmsnorm
 

Diff for: python/flashinfer/decode.py

+4-4
@@ -632,9 +632,9 @@ def forward_return_lse(
 
 
 class CUDAGraphBatchDecodeWithPagedKVCacheWrapper:
-    r"""CUDAGraph-compatible Wrapper class for decode attention with paged kv-cache (first
-    proposed in `vLLM <https://arxiv.org/abs/2309.06180>`_) for batch of requests.
-
+    r"""CUDAGraph-compatible Wrapper class for decode attention with paged kv-cache (first
+    proposed in `vLLM <https://arxiv.org/abs/2309.06180>`_) for batch of requests.
+
     Note that this wrapper may not be as efficient as :class:`BatchDecodeWithPagedKVCacheWrapper`
     because we won't dispatch to different kernels for different batch sizes/sequence lengths/etc
     to accomodate the CUDAGraph requirement.

@@ -673,7 +673,7 @@ def __init__(
         during the lifecycle of this wrapper.
     indices_buffer : torch.Tensor
         The user reserved buffer on GPU to store the page indices of the paged kv cache,
-        should be large enough to store the maximum number of page indices
+        should be large enough to store the maximum number of page indices
         (``max_num_pages``) during the lifecycle of this wrapper.
     last_page_len_buffer : torch.Tensor
         The user reserved buffer on GPU to store the number of entries in the last page,

Diff for: python/flashinfer/sampling.py

+104-5
@@ -30,7 +30,7 @@
 
 
 def sampling_from_probs(probs: torch.Tensor, uniform_samples: torch.Tensor):
-    r"""Category sampling from probabilities.
+    r"""Fused GPU kernel for category sampling from probabilities.
 
     Parameters
     ----------

@@ -75,8 +75,11 @@ def sampling_from_probs(probs: torch.Tensor, uniform_samples: torch.Tensor):
 def top_p_sampling_from_probs(
     probs: torch.Tensor, uniform_samples: torch.Tensor, top_p: float
 ):
-    r"""Top-p sampling (nucleus sampling) from probabilities, this operator implements
-    GPU-based rejection sampling without explicit sorting.
+    r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
+    this operator implements GPU-based rejection sampling without explicit sorting.
+
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than the naive implementation that launches a series of kernels.
 
     Parameters
     ----------

@@ -134,8 +137,11 @@ def top_p_sampling_from_probs(
 def top_k_sampling_from_probs(
     probs: torch.Tensor, uniform_samples: torch.Tensor, top_k: int
 ):
-    r"""Top-k sampling from probabilities, this operator implements GPU-based rejection sampling
-    without explicit sorting.
+    r"""Fused GPU kernel for top-k sampling from probabilities,
+    this operator implements GPU-based rejection sampling without explicit sorting.
+
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than the naive implementation that launches a series of kernels.
 
     Parameters
     ----------

@@ -188,3 +194,96 @@ def top_k_sampling_from_probs(
     implementation usually use much fewer rounds for rejection sampling because of early stopping.
     """
     return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k)
+
+
+def top_p_renorm_prob(probs: torch.Tensor, top_p: float, eps: float = 1e-5):
+    r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities, shape ``(batch_size, num_classes)``.
+    top_p: float
+        The threshold for re-normalizing probabilities, should be in ``(0, 1)``.
+        We mask out the probabilities less than ``threshold``, where ``threshold`` is chosen
+        such that the cumulative sum of ``probs[probs >= threshold]`` is ``top_p``,
+        and renormalize the remaining probabilities.
+    eps: float
+        The epsilon value for numerical stability.
+
+    Returns
+    -------
+    renorm_probs: torch.Tensor
+        Renormalized probabilities, shape ``(batch_size, num_classes)``.
+
+    The combination of ``top_p_renorm_prob`` and ``sampling_from_probs`` should be equivalent to
+    ``top_p_sampling_from_probs``.
+    """
+    return _kernels.top_p_renorm_prob(probs, top_p, eps)
+
+
+def top_k_renorm_prob(probs: torch.Tensor, top_k: int, eps: float = 1e-5):
+    r"""Fused GPU kernel for renormalizing probabilities by top-k thresholding.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities, shape ``(batch_size, num_classes)``.
+    top_k: int
+        The threshold for re-normalizing probabilities, should be in ``(0, num_classes)``.
+        We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
+    eps: float
+        The epsilon value for numerical stability.
+
+    Returns
+    -------
+    renorm_probs: torch.Tensor
+        Renormalized probabilities, shape ``(batch_size, num_classes)``.
+
+    Note
+    ----
+    The combination of ``top_k_renorm_prob`` and ``sampling_from_probs`` should be equivalent to
+    ``top_k_sampling_from_probs``.
+    """
+    return _kernels.top_k_renorm_prob(probs, top_k, eps)
+
+
+def chain_speculative_sampling(
+    draft_probs,
+    draft_token_ids,
+    uniform_samples,
+    target_probs,
+):
+    r"""Fused GPU kernel for speculative sampling for sequence generation (proposed in
+    the paper `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/pdf/2302.01318>`_),
+    where the draft model generates a sequence (chain) of tokens for each request.
+
+    Parameters
+    ----------
+    draft_probs: torch.Tensor
+        The probability over vocabulary generated by the draft model.
+        Shape: ``(batch_size, num_speculate_tokens, vocab_size)``
+    draft_token_ids: torch.Tensor
+        The draft model's generated token indices.
+        Shape: ``(batch_size, num_speculate_tokens)``
+    uniform_samples: torch.Tensor
+        The uniform samples used as needle for sampling, shape ``(batch_size, num_speculate_tokens + 1)``.
+        Expected to be uniformly distributed in ``[0, 1)``.
+    target_probs: torch.Tensor
+        The probability over vocabulary generated by the target model.
+        Compared to input :attr:`draft_probs`, the target model's probability has an additional
+        slot at the end because the target model will generate one more token than the draft model.
+        Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)``
+
+    Returns
+    -------
+    output_token_ids: torch.Tensor
+        The output token indices verified by the target model; rejected samples are
+        padded with ``-1``.
+        Compared to input :attr:`draft_token_ids`, the output tensor has an additional
+        token index at the end for the final token; if all previous tokens are accepted,
+        another "bonus" token will be sampled from the target model's probability.
+        Shape: ``(batch_size, num_speculate_tokens + 1)``
+    """
+    return _kernels.chain_speculative_sampling(
+        draft_probs, draft_token_ids, uniform_samples, target_probs
+    )
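As a usage sketch of the note in the renorm docstrings above: masking with top_p_renorm_prob / top_k_renorm_prob and then drawing with sampling_from_probs is intended to match the fused top_p_sampling_from_probs / top_k_sampling_from_probs paths distribution-wise (exact samples may differ, since the fused samplers use rejection rounds). The ``(batch_size,)`` shape for uniform_samples follows the pre-existing sampling_from_probs API and the concrete sizes here are made up.

import torch
import flashinfer

batch_size, vocab_size = 4, 128
probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)
uniform_samples = torch.rand(batch_size, device="cuda")

# Top-p: zero out the tail outside the nucleus, renormalize, then sample.
top_p_probs = flashinfer.top_p_renorm_prob(probs, top_p=0.9)
top_p_samples = flashinfer.sampling_from_probs(top_p_probs, uniform_samples)

# Top-k: keep the k largest entries, renormalize, then sample.
top_k_probs = flashinfer.top_k_renorm_prob(probs, top_k=20)
top_k_samples = flashinfer.sampling_from_probs(top_k_probs, uniform_samples)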

Diff for: python/setup.py

+3-3
@@ -64,14 +64,14 @@ def get_instantiation_cu() -> List[str]:
     (root / prefix).mkdir(parents=True, exist_ok=True)
 
     group_sizes = os.environ.get("FLASHINFER_GROUP_SIZES", "1,4,6,8").split(",")
-    page_sizes = os.environ.get("FLASHINFER_PAGE_SIZES", "1,16,32").split(",")
-    head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")
+    page_sizes = os.environ.get("FLASHINFER_PAGE_SIZES", "1").split(",")
+    head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "128,256").split(",")
     kv_layouts = os.environ.get("FLASHINFER_KV_LAYOUTS", "0,1").split(",")
     pos_encoding_modes = os.environ.get("FLASHINFER_POS_ENCODING_MODES", "0,1,2").split(
         ","
     )
     allow_fp16_qk_reduction_options = os.environ.get(
-        "FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0,1"
+        "FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0"
     ).split(",")
     causal_options = os.environ.get("FLASHINFER_CAUSAL_OPTIONS", "0,1").split(",")
     # dispatch.inc
