[Sampler] Adapt to FlashInfer 0.2.3 sampler API #15777

Merged: 16 commits, May 16, 2025

3 changes: 2 additions & 1 deletion docker/Dockerfile
@@ -255,9 +255,10 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
RUN --mount=type=cache,target=/root/.cache/uv \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
# uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.4/flashinfer_python-0.2.4+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
# TESTING: install FlashInfer from source to test 2.7.0 final RC
FLASHINFER_ENABLE_AOT=1 TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' \
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.2.post1" ; \
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.4" ; \
fi
COPY examples examples
COPY benchmarks benchmarks
14 changes: 12 additions & 2 deletions tests/samplers/test_rejection_sampler.py
@@ -169,7 +169,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
# @pytest.mark.parametrize("use_flashinfer", [True, False])
# Not testing FlashInfer now, since 0.2.3 API removed the ability
# to pass in uniform samples.
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int, device: str,
@@ -214,7 +217,10 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
# @pytest.mark.parametrize("use_flashinfer", [True, False])
# Not testing FlashInfer now, since 0.2.3 API removed the ability
# to pass in uniform samples.
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool):
@@ -284,6 +290,10 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
Test that the flashinfer and nonflashinfer backends generate
the same output metrics.
"""

pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed "
"the ability to pass in uniform samples.")

torch.set_default_device(device)
torch.manual_seed(0)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
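
For context on why these cases are disabled: before 0.2.3, FlashInfer's rejection-sampling API accepted a caller-built tensor of uniform samples, so a seeded torch.Generator made repeated runs bit-identical and hence testable. From 0.2.3 the kernel draws its randomness internally, leaving only the use_flashinfer=False path deterministic. A minimal sketch of the old seeding pattern (helper name and shapes are illustrative, not vLLM code):

import torch

def seeded_uniform(batch_size: int, k: int, seed: int,
                   device: str = "cpu") -> torch.Tensor:
    # Pre-0.2.3 pattern: the caller owned the randomness, so seeding the
    # generator pinned down every accept/reject decision downstream.
    gen = torch.Generator(device=device).manual_seed(seed)
    return torch.rand(batch_size, k + 1, generator=gen, device=device)

u1 = seeded_uniform(batch_size=8, k=3, seed=42)
u2 = seeded_uniform(batch_size=8, k=3, seed=42)
assert torch.equal(u1, u2)  # identical samples -> identical sampler output
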
2 changes: 2 additions & 0 deletions tests/samplers/test_sampler.py
@@ -647,6 +647,8 @@ def test_flashinfer_fallback(seed: int, device: str):
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
pytest.skip("Flashinfer sampler is disabled")

pytest.skip("After FlashInfer 0.2.3, sampling will never fail")

set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
72 changes: 71 additions & 1 deletion tests/v1/sample/test_topk_topp_sampler.py
@@ -1,14 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
from torch import Generator

from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
is_flashinfer_available)

DEVICE = "cuda"

BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024

FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available


def test_topk_impl_equivalance():

@@ -35,3 +41,67 @@ def test_topk_impl_equivalance():
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)

assert torch.allclose(result1, result2)


def test_flashinfer_sampler():
'''
This test verifies that the FlashInfer top-k and top-p sampling
implementation produces the same results as the Python implementation.

NOTE: FlashInfer does not directly expose an interface for fused top-k and
top-p prob renorm (it does provide fused sampling, but sampled tokens cannot
be compared due to randomness), so we instead compare the probabilities
renormalized consecutively by FlashInfer's top-k and then top-p renorm kernels.
'''

if not FLASHINFER_ENABLED:
pytest.skip(
"FlashInfer not installed or not available on this platform.")

with torch.device(DEVICE):
generator = Generator(device=DEVICE).manual_seed(42)

# Generate random logits
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

# Generate various top-k and top-p values
k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
p_values = torch.rand(
(BATCH_SIZE, ),
generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0]

# Sometimes disable top-k (k=vocab_size)
k_values.masked_fill_(
torch.randint(0,
2, (BATCH_SIZE, ),
generator=generator,
dtype=torch.bool), VOCAB_SIZE)

# Sometimes disable top-p (p=1.0)
p_values.masked_fill_(
torch.randint(0,
2, (BATCH_SIZE, ),
generator=generator,
dtype=torch.bool), 1.0)

python_logits = apply_top_k_top_p(
logits=logits.clone(),
k=k_values,
p=p_values,
)
python_probs = torch.softmax(python_logits, dim=-1)

# FlashInfer only exposed renorm interfaces for probs so convert first
flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
flashinfer_probs = top_k_renorm_probs(
probs=flashinfer_probs,
top_k=k_values,
)
flashinfer_probs = top_p_renorm_probs(
probs=flashinfer_probs,
top_p=p_values,
)

# Compare the results
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
"FlashInfer and Python sampling implementations do not match!"
13 changes: 7 additions & 6 deletions vllm/model_executor/layers/rejection_sampler.py
@@ -123,12 +123,13 @@ def forward(
# for rejection sampling
if self.use_flashinfer and chain_speculative_sampling is not None:
batch_size, k, _ = draft_probs.shape
uniform_samples = self._create_uniform_samples(
seeded_seqs, batch_size, k, draft_probs.device)
output_token_ids, accepted_token_num, emitted_token_num \
= chain_speculative_sampling(
draft_probs, draft_token_ids, uniform_samples,
target_with_bonus_probs)

(output_token_ids, accepted_token_num,
emitted_token_num) = chain_speculative_sampling(
draft_probs,
draft_token_ids,
target_with_bonus_probs,
)

# num_emitted_tokens returned by flashinfer
# does not include the bonus token
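
The call-shape change behind this hunk, reduced to a standalone sketch. Shapes follow the vLLM rejection sampler (draft_probs is [batch, k, vocab] and the target probs carry one bonus position); running it requires a CUDA build of FlashInfer, and exact dtype requirements should be checked against the FlashInfer docs rather than taken from this sketch.

import torch
from flashinfer.sampling import chain_speculative_sampling

batch, k, vocab = 4, 3, 1024
device = "cuda"
draft_probs = torch.rand(batch, k, vocab, device=device).softmax(dim=-1)
target_with_bonus = torch.rand(batch, k + 1, vocab, device=device).softmax(dim=-1)
draft_token_ids = torch.randint(vocab, (batch, k), device=device)

# Pre-0.2.3 (removed above): the caller supplied the uniform samples.
#   uniform = torch.rand(batch, k + 1, device=device)
#   out, accepted, emitted = chain_speculative_sampling(
#       draft_probs, draft_token_ids, uniform, target_with_bonus)

# 0.2.3+: randomness is generated inside the kernel, so there is no hook
# for reproducing a run from a user-supplied generator.
out, accepted, emitted = chain_speculative_sampling(
    draft_probs, draft_token_ids, target_with_bonus)
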
52 changes: 13 additions & 39 deletions vllm/model_executor/layers/sampler.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from collections.abc import Iterator
from dataclasses import dataclass
from importlib.util import find_spec
@@ -24,7 +23,6 @@
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
@@ -33,6 +31,10 @@
else:
flashinfer_top_k_top_p_sampling = None

from vllm.logger import init_logger

logger = init_logger(__name__)


def get_sampler() -> torch.nn.Module:
if envs.VLLM_USE_V1:
@@ -545,38 +547,15 @@ def _multinomial(
def _top_k_top_p_multinomial_with_flashinfer(
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]):
max_top_k_round = 32
if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0)
top_ks = top_ks.repeat_interleave(num_samples)
top_ps = top_ps.repeat_interleave(num_samples)
batch_size = probs.shape[0]
uniform_samples = torch.empty((max_top_k_round, batch_size),
device=probs.device)
if seq_groups is None:
uniform_samples.uniform_()
else:
sample_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
stride = len(seq_ids) * num_samples
assert seq_group.generator is not None
uniform_samples[:, sample_idx:sample_idx +
stride].uniform_(generator=seq_group.generator)
sample_idx += stride
batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
batch_next_token_ids = flashinfer_top_k_top_p_sampling(
probs,
uniform_samples,
top_ks,
top_ps,
)
if not success.all():
warnings.warn("FlashInfer rejection sampling failed, fallback.",
stacklevel=1)
probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
probs, uniform_samples[0])
return batch_next_token_ids.view(-1, num_samples)


@@ -712,19 +691,14 @@ def _sample_with_torch(
seq_groups)

if flashinfer_top_k_top_p_sampling is not None:
multinomial_samples[
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
probs[long_sample_indices],
sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices],
max_n_in_batch,
seq_groups_arg,
)
else:
multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices],
max_n_in_batch,
seq_groups=seq_groups_arg)
logger.warning("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")

multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices],
max_n_in_batch,
seq_groups=seq_groups_arg)

if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
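
After this change the V0 helper maps onto the 0.2.3 entry point directly: no pre-built uniform_samples tensor, no (ids, success) tuple, and therefore no renorm-and-resample fallback after the call. A trimmed sketch of that reduced flow (not the vLLM function itself; it assumes FlashInfer is installed and probs already sits on the GPU):

import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs

def topk_topp_multinomial(probs: torch.Tensor, top_ks: torch.Tensor,
                          top_ps: torch.Tensor,
                          num_samples: int) -> torch.Tensor:
    # Duplicate rows when more than one sample per sequence is requested,
    # exactly as the helper above does.
    if num_samples > 1:
        probs = probs.repeat_interleave(num_samples, dim=0)
        top_ks = top_ks.repeat_interleave(num_samples)
        top_ps = top_ps.repeat_interleave(num_samples)
    # 0.2.3+ returns token ids directly; there is no success mask to check.
    ids = top_k_top_p_sampling_from_probs(probs, top_ks, top_ps)
    return ids.view(-1, num_samples)

Seeded requests cannot take this route at all any more, which is why the _sample_with_torch hunk above logs the per-request-generator warning and falls back to the PyTorch-native _multinomial path.
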
56 changes: 16 additions & 40 deletions vllm/v1/sample/ops/topk_topp_sampler.py
@@ -31,21 +31,10 @@ def __init__(self):
if current_platform.is_cuda():
if is_flashinfer_available:
flashinfer_version = flashinfer.__version__
if flashinfer_version >= "0.2.3":
# FIXME(DefTruth): Currently, we have errors when using
# FlashInfer>=v0.2.3 for top-p & top-k sampling. As a
# workaround, we disable FlashInfer for top-p & top-k
# sampling by default while FlashInfer>=v0.2.3.
# The sampling API removes the success return value
# of all sampling API, which is not compatible with
# earlier design.
# https://github.com/flashinfer-ai/flashinfer/releases/
# tag/v0.2.3
logger.info(
"Currently, FlashInfer top-p & top-k sampling sampler "
"is disabled because FlashInfer>=v0.2.3 is not "
"backward compatible. Falling back to the PyTorch-"
"native implementation of top-p & top-k sampling.")
if flashinfer_version < "0.2.3":
logger.warning(
"FlashInfer version >= 0.2.3 required. "
"Falling back to default sampling implementation.")
self.forward = self.forward_native
elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
@@ -106,6 +95,11 @@ def forward_cuda(
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
return random_sample(probs, generators)
if generators:
logger.warning("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)
return flashinfer_sample(probs, k, p, generators)

def forward_tpu(
@@ -280,36 +274,18 @@ def flashinfer_sample(
the synchronization overhead.
"""
assert not (k is None and p is None)
max_top_k_round = 32
batch_size = probs.shape[0]
uniform_samples = torch.empty((max_top_k_round, batch_size),
device=probs.device)
if len(generators) != batch_size:
uniform_samples.uniform_()
if generators:
for i, generator in generators.items():
uniform_samples[:, i].uniform_(generator=generator)

if k is None:
# Top-p only.
next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
probs, uniform_samples, p, deterministic=True)
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
probs, p, deterministic=True)
elif p is None:
# Top-k only.
next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
probs, uniform_samples, k, deterministic=True)
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
probs, k, deterministic=True)
else:
# Both top-k and top-p.
next_token_ids, success = (
flashinfer.sampling.top_k_top_p_sampling_from_probs(
probs, uniform_samples, k, p, deterministic=True))

# NOTE: CPU-GPU synchronization happens here.
if not success.all():
if k is not None:
probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
if p is not None:
probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
next_token_ids = flashinfer.sampling.sampling_from_probs(
probs, uniform_samples[0], deterministic=True)
next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
probs, k, p, deterministic=True))

return next_token_ids.view(-1)
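
Taken together, the V1 sampler's dispatch after this PR comes down to a handful of checks split between __init__ and forward_cuda. A condensed sketch that folds them into one function for readability; names are illustrative and some details (such as the all-greedy early return) are omitted:

def choose_sampling_path(flashinfer_available: bool,
                         flashinfer_version: str,
                         flashinfer_sampler_env_disabled: bool,
                         have_per_request_generators: bool) -> str:
    # Mirrors the version/env gate in TopKTopPSampler.__init__ plus the
    # per-request generator check added to forward_cuda above.
    if not flashinfer_available or flashinfer_version < "0.2.3":
        return "forward_native"      # FlashInfer missing or older than 0.2.3
    if flashinfer_sampler_env_disabled:
        return "forward_native"      # VLLM_USE_FLASHINFER_SAMPLER set to off
    if have_per_request_generators:
        return "forward_native"      # 0.2.3+ has no per-request generator hook
    return "flashinfer_sample"       # fused top-k/top-p sampling on the GPU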