Skip to content

[V1][Sampler] Improve performance of FlashInfer sampling by sampling logits instead of probs #18608

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 26, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions vllm/v1/sample/ops/topk_topp_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,18 @@ def forward_cuda(
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling."""
probs = logits.softmax(dim=-1, dtype=torch.float32)
if k is None and p is None:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
if generators:
logger.warning("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)
return flashinfer_sample(probs, k, p, generators)
return flashinfer_sample(logits, k, p, generators)

def forward_tpu(
self,
Expand Down Expand Up @@ -254,17 +254,17 @@ def random_sample(


def flashinfer_sample(
probs: torch.Tensor,
logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from the probabilities using FlashInfer.
"""Sample from the logits using FlashInfer.

Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor
via rejection sampling.

NOTE: The outputs of this function do not necessarily match the outputs of
the `random_sample` function. It only guarantees that the outputs are
statistically equivalent.
Expand All @@ -274,18 +274,19 @@ def flashinfer_sample(
the synchronization overhead.
"""
assert not (k is None and p is None)

if k is None:
# Top-p only.
probs = logits.softmax(dim=-1, dtype=torch.float32)
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
probs, p, deterministic=True)
elif p is None:
# Top-k only.
probs = logits.softmax(dim=-1, dtype=torch.float32)
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
probs, k, deterministic=True)
else:
# Both top-k and top-p.
next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
probs, k, p, deterministic=True))
next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
logits, k, p, deterministic=True)

return next_token_ids.view(-1)