
Commit 2eb6e2f

Author: Ubuntu

[Sampler] Use FlashInfer sampling from logits

1 parent d0bc2f8 commit 2eb6e2f

File tree

1 file changed: +8 -7 lines changed


vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 8 additions & 7 deletions
@@ -89,18 +89,18 @@ def forward_cuda(
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """More optimized implementation for top-k and top-p sampling."""
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
         if k is None and p is None:
             # We prefer `random_sample` over `flashinfer_sample` when sorting is
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
+            probs = logits.softmax(dim=-1, dtype=torch.float32)
             return random_sample(probs, generators)
         if generators:
             logger.warning("FlashInfer 0.2.3+ does not support "
                            "per-request generators. Falling back to "
                            "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
-        return flashinfer_sample(probs, k, p, generators)
+        return flashinfer_sample(logits, k, p, generators)

     def forward_tpu(
         self,
@@ -254,7 +254,7 @@ def random_sample(


 def flashinfer_sample(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     k: Optional[torch.Tensor],
     p: Optional[torch.Tensor],
     generators: dict[int, torch.Generator],
@@ -264,7 +264,7 @@ def flashinfer_sample(
     Statistically, this function is equivalent to the `random_sample` function.
     However, this function is faster because it avoids sorting the logits tensor
     via rejection sampling.
-
+
     NOTE: The outputs of this function do not necessarily match the outputs of
     the `random_sample` function. It only guarantees that the outputs are
     statistically equivalent.
@@ -274,18 +274,19 @@ def flashinfer_sample(
     the synchronization overhead.
     """
     assert not (k is None and p is None)
-
     if k is None:
         # Top-p only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
             probs, p, deterministic=True)
     elif p is None:
         # Top-k only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
             probs, k, deterministic=True)
     else:
         # Both top-k and top-p.
-        next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
-            probs, k, p, deterministic=True))
+        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
+            logits, k, p, deterministic=True)

     return next_token_ids.view(-1)
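
For reference, a minimal sketch (not part of the commit) of what the change buys: the combined top-k/top-p path now calls `top_k_top_p_sampling_from_logits`, which takes raw logits and normalizes inside FlashInfer, so the caller only materializes a full-vocabulary float32 probs tensor in the branches that still sample from probabilities. FlashInfer 0.2.3+ and a CUDA device are assumed; the batch size, vocabulary size, and k/p values below are illustrative.

# Minimal sketch, not from the commit: old vs. new sampling entry points.
# Assumes FlashInfer 0.2.3+ on a CUDA device; shapes, dtypes, and the
# k/p values here are illustrative assumptions, not vLLM code.
import torch
import flashinfer

batch_size, vocab_size = 4, 32000
logits = torch.randn(batch_size, vocab_size, device="cuda")
k = torch.full((batch_size,), 40, device="cuda")  # per-request top-k
p = torch.full((batch_size,), 0.9, device="cuda")  # per-request top-p

# Before this commit: materialize float32 probabilities up front,
# then sample from the probs tensor.
probs = logits.softmax(dim=-1, dtype=torch.float32)
ids_old = flashinfer.sampling.top_k_top_p_sampling_from_probs(
    probs, k, p, deterministic=True)

# After this commit: sample directly from the logits; normalization
# happens inside FlashInfer, so the extra full-vocabulary probs tensor
# is skipped on this path.
ids_new = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, k, p, deterministic=True)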
