Skip to content

Commit e7523c2

Browse files
authored
[V1][Sampler] Improve performance of FlashInfer sampling by sampling logits instead of probs (#18608)
1 parent a869bac commit e7523c2

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,18 @@ def forward_cuda(
8989
p: Optional[torch.Tensor],
9090
) -> torch.Tensor:
9191
"""More optimized implementation for top-k and top-p sampling."""
92-
probs = logits.softmax(dim=-1, dtype=torch.float32)
9392
if k is None and p is None:
9493
# We prefer `random_sample` over `flashinfer_sample` when sorting is
9594
# not needed. This is because `random_sample` does not require
9695
# CPU-GPU synchronization while `flashinfer_sample` does.
96+
probs = logits.softmax(dim=-1, dtype=torch.float32)
9797
return random_sample(probs, generators)
9898
if generators:
9999
logger.warning("FlashInfer 0.2.3+ does not support "
100100
"per-request generators. Falling back to "
101101
"PyTorch-native implementation.")
102102
return self.forward_native(logits, generators, k, p)
103-
return flashinfer_sample(probs, k, p, generators)
103+
return flashinfer_sample(logits, k, p, generators)
104104

105105
def forward_tpu(
106106
self,
@@ -254,17 +254,17 @@ def random_sample(
254254

255255

256256
def flashinfer_sample(
257-
probs: torch.Tensor,
257+
logits: torch.Tensor,
258258
k: Optional[torch.Tensor],
259259
p: Optional[torch.Tensor],
260260
generators: dict[int, torch.Generator],
261261
) -> torch.Tensor:
262-
"""Sample from the probabilities using FlashInfer.
262+
"""Sample from the logits using FlashInfer.
263263
264264
Statistically, this function is equivalent to the `random_sample` function.
265265
However, this function is faster because it uses rejection sampling,
266266
which avoids sorting the logits tensor.
267-
267+
268268
NOTE: The outputs of this function do not necessarily match the outputs of
269269
the `random_sample` function. It only guarantees that the outputs are
270270
statistically equivalent.
@@ -274,18 +274,19 @@ def flashinfer_sample(
274274
the synchronization overhead.
275275
"""
276276
assert not (k is None and p is None)
277-
278277
if k is None:
279278
# Top-p only.
279+
probs = logits.softmax(dim=-1, dtype=torch.float32)
280280
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
281281
probs, p, deterministic=True)
282282
elif p is None:
283283
# Top-k only.
284+
probs = logits.softmax(dim=-1, dtype=torch.float32)
284285
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
285286
probs, k, deterministic=True)
286287
else:
287288
# Both top-k and top-p.
288-
next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
289-
probs, k, p, deterministic=True))
289+
next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
290+
logits, k, p, deterministic=True)
290291

291292
return next_token_ids.view(-1)

0 commit comments

Comments
 (0)