[CHORE] Rename output_emitted_token_num -> output_emitted_draft_token_num #977

Merged · 2 commits · Mar 29, 2025
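This rename clarifies what the counter actually measures: only the draft tokens that survive verification are counted, not the bonus token resampled from the target distribution, so the old name output_emitted_token_num read like a total token count. The new name output_emitted_draft_token_num is applied consistently across the CUDA kernel, the C++/Torch bindings, and the Python API in the diffs below; a duplicated-variable bug in the docstring example is fixed along the way. A minimal sketch of the renamed Python entry point, built from the updated docstring's toy example (the draft_probs values are illustrative, chosen to match the docstring comments about which tokens the draft model sampled; assumes flashinfer is installed with CUDA support):

import torch
import flashinfer

# Vocab size 4, two speculative draft tokens, batch of one.
# Illustrative draft_probs: all mass on the tokens the docstring says the
# draft model sampled (token 2 first, then token 1).
draft_probs = torch.tensor([[[0.0, 0.0, 1.0, 0.0],
                             [0.0, 1.0, 0.0, 0.0]]]).to(0)
draft_token_ids = torch.tensor([[2, 1]], dtype=torch.int32).to(0)
target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3],
                              [1.0, 0.0, 0.0, 0.0],
                              [0.7, 0.1, 0.1, 0.1]]]).to(0)

output_token_ids, output_accepted_token_num, output_emitted_draft_token_num = (
    flashinfer.sampling.chain_speculative_sampling(
        draft_probs, draft_token_ids, target_probs
    )
)
# output_token_ids packs the emitted draft tokens plus one bonus/resampled
# token and is padded with -1; output_emitted_draft_token_num excludes the
# bonus token by definition.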
2 changes: 1 addition & 1 deletion csrc/flashinfer_ops.cu
@@ -209,7 +209,7 @@ void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
 void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
                                 at::Tensor target_probs, at::Tensor output_token_ids,
                                 at::Tensor output_accepted_token_num,
-                                at::Tensor output_emitted_token_num, bool deterministic,
+                                at::Tensor output_emitted_draft_token_num, bool deterministic,
                                 std::optional<at::Generator> gen);

 //========== Torch Library ==========
2 changes: 1 addition & 1 deletion csrc/flashinfer_sampling_ops.cu
@@ -52,7 +52,7 @@ void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
 void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
                                 at::Tensor target_probs, at::Tensor output_token_ids,
                                 at::Tensor output_accepted_token_num,
-                                at::Tensor output_emitted_token_num, bool deterministic,
+                                at::Tensor output_emitted_draft_token_num, bool deterministic,
                                 std::optional<at::Generator> gen);

 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
8 changes: 4 additions & 4 deletions csrc/sampling.cu
@@ -186,7 +186,7 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
 void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
                                 at::Tensor target_probs, at::Tensor output_token_ids,
                                 at::Tensor output_accepted_token_num,
-                                at::Tensor output_emitted_token_num, bool deterministic,
+                                at::Tensor output_emitted_draft_token_num, bool deterministic,
                                 std::optional<at::Generator> gen_) {
   CHECK_INPUT(draft_probs);
   CHECK_INPUT(draft_token_ids);
@@ -205,7 +205,7 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
   CHECK_EQ(num_speculate_tokens + 1, target_probs.size(1));
   CHECK_EQ(vocab_size, target_probs.size(2));
   CHECK_EQ(batch_size, output_accepted_token_num.size(0));
-  CHECK_EQ(batch_size, output_emitted_token_num.size(0));
+  CHECK_EQ(batch_size, output_emitted_draft_token_num.size(0));
   uint64_t philox_seed, philox_offset;
   auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
       gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -221,8 +221,8 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
       static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
       static_cast<float*>(target_probs.data_ptr()), static_cast<int*>(output_token_ids.data_ptr()),
       static_cast<int*>(output_accepted_token_num.data_ptr()),
-      static_cast<int*>(output_emitted_token_num.data_ptr()), batch_size, num_speculate_tokens,
-      vocab_size, deterministic, philox_seed, philox_offset, stream);
+      static_cast<int*>(output_emitted_draft_token_num.data_ptr()), batch_size,
+      num_speculate_tokens, vocab_size, deterministic, philox_seed, philox_offset, stream);

   TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
                                          std::string(cudaGetErrorString(status)));
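The CHECK_EQ guards above pin down the calling contract: draft_probs is (batch_size, num_speculate_tokens, vocab_size), target_probs carries one extra position for the bonus token, and both counters are length-batch_size int32 vectors. A hedged sketch of inputs that satisfy these checks (all sizes illustrative, not taken from the PR):

import torch

batch_size, num_speculate_tokens, vocab_size = 4, 3, 128  # illustrative sizes
dev = torch.device("cuda:0")

draft_probs = torch.rand(batch_size, num_speculate_tokens, vocab_size, device=dev)
draft_probs = draft_probs / draft_probs.sum(-1, keepdim=True)  # rows are distributions
draft_token_ids = torch.randint(
    vocab_size, (batch_size, num_speculate_tokens), dtype=torch.int32, device=dev
)
# The target model scores every draft position plus one bonus position: n + 1.
target_probs = torch.rand(batch_size, num_speculate_tokens + 1, vocab_size, device=dev)
target_probs = target_probs / target_probs.sum(-1, keepdim=True)

# Per-request counters, matching CHECK_EQ(batch_size, ....size(0)) above.
output_accepted_token_num = torch.zeros(batch_size, dtype=torch.int32, device=dev)
output_emitted_draft_token_num = torch.zeros(batch_size, dtype=torch.int32, device=dev)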
41 changes: 24 additions & 17 deletions flashinfer/sampling.py
@@ -333,14 +333,17 @@ def _fake_top_k_mask_logits(

 @register_custom_op(
     "flashinfer::chain_speculative_sampling",
-    mutates_args=("output_accepted_token_num", "output_emitted_token_num"),
+    mutates_args=(
+        "output_accepted_token_num",
+        "output_emitted_draft_token_num",
+    ),
 )
 def chain_speculative_sampling(
     draft_probs: torch.Tensor,
     draft_token_ids: torch.Tensor,
     target_probs: torch.Tensor,
     output_accepted_token_num: torch.Tensor,
-    output_emitted_token_num: torch.Tensor,
+    output_emitted_draft_token_num: torch.Tensor,
     deterministic: bool,
     generator: Optional[torch.Generator],
 ) -> torch.Tensor:
@@ -349,7 +352,7 @@ def chain_speculative_sampling(
     draft_token_ids = draft_token_ids.int()
     target_probs = target_probs.float()
     output_accepted_token_num = output_accepted_token_num.int()
-    output_emitted_token_num = output_emitted_token_num.int()
+    output_emitted_draft_token_num = output_emitted_draft_token_num.int()
     b, n = draft_token_ids.shape
     output_token_ids = torch.empty((b, n + 1), dtype=torch.int32, device=device)
     module.chain_speculative_sampling.default(
@@ -358,7 +361,7 @@
         target_probs,
         output_token_ids,
         output_accepted_token_num,
-        output_emitted_token_num,
+        output_emitted_draft_token_num,
         deterministic,
         generator,
     )
@@ -370,7 +373,7 @@ def _fake_chain_speculative_sampling(
     draft_token_ids: torch.Tensor,
     target_probs: torch.Tensor,
     output_accepted_token_num: torch.Tensor,
-    output_emitted_token_num: torch.Tensor,
+    output_emitted_draft_token_num: torch.Tensor,
     deterministic: bool,
     generator: Optional[torch.Generator],
 ) -> torch.Tensor:
@@ -1130,7 +1133,7 @@ def chain_speculative_sampling(
     draft_token_ids,
     target_probs,
     maybe_output_accepted_token_num: Optional[torch.Tensor] = None,
-    maybe_output_emitted_token_num: Optional[torch.Tensor] = None,
+    maybe_output_emitted_draft_token_num: Optional[torch.Tensor] = None,
     deterministic: bool = True,
     generator: Optional[torch.Generator] = None,
 ) -> torch.Tensor:
@@ -1158,8 +1161,10 @@ def chain_speculative_sampling(
         It only evaluates the alignment of the draft and target models.
         Shape: ``(batch_size)``
         If specified, the number of accepted tokens will be added to this tensor in place. Default is ``None``.
-    maybe_output_emitted_token_num: Optional[torch.Tensor]
-        The number of tokens that are finally emitted/generated for each request.
+    maybe_output_emitted_draft_token_num: Optional[torch.Tensor]
+        The number of draft tokens that are finally emitted for each request. Does not include
+        the bonus token. (Thus the total number of tokens sampled for a given request is
+        ``output_emitted_draft_token_num + 1``.)
         Shape: ``(batch_size)``
         If specified, the number of emitted draft tokens will be added to this tensor in place. Default is ``None``.
     deterministic: bool
@@ -1182,8 +1187,10 @@ def chain_speculative_sampling(
         satisfy the probability requirement r < p/q.
         It only evaluates the alignment of the draft and target models.
         Shape: ``(batch_size)``
-    output_emitted_token_num: torch.Tensor
-        The number of tokens that are finally emitted/generated for each request.
+    output_emitted_draft_token_num: torch.Tensor
+        The number of draft tokens that are finally emitted for each request. Does not include
+        the bonus token. (Thus the total number of tokens sampled for a given request is
+        ``output_emitted_draft_token_num + 1``.)
         Shape: ``(batch_size)``

     Examples
@@ -1200,7 +1207,7 @@ def chain_speculative_sampling(
     >>> # token 1 was sampled from draft model for the second token
     >>> draft_token_ids = torch.tensor([[2, 1]], dtype=torch.int32).to(0)
     >>> target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3], [1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1]]]).to(0)
-    >>> output_token_ids, output_accepted_token_num, output_accepted_token_num =\
+    >>> output_token_ids, output_accepted_token_num, output_emitted_draft_token_num =\
     ...     flashinfer.sampling.chain_speculative_sampling(
     ...         draft_probs, draft_token_ids, target_probs)
     >>> # the first token is accepted, the second token is rejected and sampled from the difference
@@ -1209,7 +1216,7 @@
     tensor([[ 2, 0, -1]], device='cuda:0', dtype=torch.int32)
     >>> output_accepted_token_num
     tensor([1], device='cuda:0')
-    >>> output_emitted_token_num
+    >>> output_emitted_draft_token_num
     tensor([1], device='cuda:0')
     """
     b = draft_probs.size(0)
@@ -1218,17 +1225,17 @@
         output_accepted_token_num = torch.zeros(b, dtype=torch.int32, device=dev)
     else:
         output_accepted_token_num = maybe_output_accepted_token_num
-    if maybe_output_emitted_token_num is None:
-        output_emitted_token_num = torch.zeros(b, dtype=torch.int32, device=dev)
+    if maybe_output_emitted_draft_token_num is None:
+        output_emitted_draft_token_num = torch.zeros(b, dtype=torch.int32, device=dev)
     else:
-        output_emitted_token_num = maybe_output_emitted_token_num
+        output_emitted_draft_token_num = maybe_output_emitted_draft_token_num
     output_token_ids = get_sampling_module().chain_speculative_sampling(
         draft_probs,
         draft_token_ids,
         target_probs,
         output_accepted_token_num,
-        output_emitted_token_num,
+        output_emitted_draft_token_num,
         deterministic,
         generator,
     )
-    return output_token_ids, output_accepted_token_num, output_emitted_token_num
+    return output_token_ids, output_accepted_token_num, output_emitted_draft_token_num
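One consequence of the clarified semantics: when the counters start at zero, the count of entries in output_token_ids that are not -1 equals output_emitted_draft_token_num + 1, since exactly one bonus or resampled token follows the emitted draft tokens. A small sanity check along those lines, reusing the outputs from the sketch at the top of this page:

# Valid entries = emitted draft tokens + 1 bonus/resampled token; the rest is -1 padding.
valid_token_count = (output_token_ids != -1).sum(dim=-1, dtype=torch.int32)
assert torch.equal(valid_token_count, output_emitted_draft_token_num + 1)
# Docstring example: output_token_ids = [[2, 0, -1]] -> 2 valid entries,
# output_emitted_draft_token_num = [1] -> 1 + 1 = 2.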
8 changes: 4 additions & 4 deletions include/flashinfer/sampling.cuh
@@ -1383,7 +1383,7 @@ template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
 __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
                                          DType* target_probs, IdType* output_token_ids,
                                          IdType* output_accepted_token_num,
-                                         IdType* output_emitted_token_num,
+                                         IdType* output_emitted_draft_token_num,
                                          uint32_t num_speculative_tokens, uint32_t d,
                                          uint64_t philox_seed, uint64_t philox_offset) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
@@ -1427,7 +1427,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,

   if (tx == 0) {
     output_accepted_token_num[row_idx] += accepted_token_num;
-    output_emitted_token_num[row_idx] += emitted_token_num;
+    output_emitted_draft_token_num[row_idx] += emitted_token_num;
   }

   // sample from relu(target_probs - draft_probs)
@@ -1517,7 +1517,7 @@ template <typename DType, typename IdType>
 cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
                                      DType* target_probs, IdType* output_token_ids,
                                      IdType* output_accepted_token_num,
-                                     IdType* output_emitted_token_num, uint32_t batch_size,
+                                     IdType* output_emitted_draft_token_num, uint32_t batch_size,
                                      uint32_t num_speculative_tokens, uint32_t d,
                                      bool deterministic, uint64_t philox_seed,
                                      uint64_t philox_offset, cudaStream_t stream = 0) {
@@ -1532,7 +1532,7 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
                   &target_probs,
                   &output_token_ids,
                   &output_accepted_token_num,
-                  &output_emitted_token_num,
+                  &output_emitted_draft_token_num,
                   &num_speculative_tokens,
                   &d,
                   &philox_seed,
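Note that the kernel accumulates with += rather than assigning, which is what lets the Python wrapper add into caller-supplied counters in place: passing the same tensors across successive speculative-decoding steps yields running totals. A hedged sketch of that pattern, reusing the illustrative inputs and counters from the sketch after csrc/sampling.cu (the decode loop and step count are hypothetical, not part of this PR):

import flashinfer

num_decode_steps = 8  # hypothetical number of speculative decode steps

for _ in range(num_decode_steps):
    # In a real loop, draft_probs / draft_token_ids / target_probs would be
    # produced fresh by the draft and target models at every step.
    output_token_ids, output_accepted_token_num, output_emitted_draft_token_num = (
        flashinfer.sampling.chain_speculative_sampling(
            draft_probs,
            draft_token_ids,
            target_probs,
            maybe_output_accepted_token_num=output_accepted_token_num,
            maybe_output_emitted_draft_token_num=output_emitted_draft_token_num,
        )
    )
# The counters now hold running totals across all steps, which is what the
# docstring means by "added to this tensor in place".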