
Commit 9a5040b

Author: Bowen Wang (committed)

[Test] Add tests for FlashInfer sampler

Signed-off-by: Bowen Wang <abmfy@icloud.com>

Parent: 2d22d0c

1 file changed: 73 additions, 1 deletion


tests/v1/sample/test_topk_topp_sampler.py

Lines changed: 73 additions & 1 deletion
@@ -2,13 +2,26 @@
 import torch
 from torch import Generator
 
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+import pytest
+
+from vllm.platforms import current_platform
+from vllm.v1.sample.ops.topk_topp_sampler import (
+    apply_top_k_top_p,
+    is_flashinfer_available,
+)
+
+# Import guarded so this file can still be collected (and the test skipped)
+# on machines where FlashInfer is not installed.
+if is_flashinfer_available:
+    from flashinfer.sampling import (
+        top_k_renorm_probs,
+        top_p_renorm_probs,
+    )
 
 DEVICE = "cuda"
 
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
 
+FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
+
 
 def test_topk_impl_equivalance():

@@ -35,3 +48,62 @@ def test_topk_impl_equivalance():
     result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
 
     assert torch.allclose(result1, result2)
+
+
+def test_flashinfer_sampler():
+    '''
+    This test verifies that the FlashInfer top-k and top-p sampling
+    implementation produces the same results as the Python implementation.
+
+    NOTE: FlashInfer does not directly expose an interface for fused top-k
+    and top-p probability renormalization (it does provide fused sampling,
+    but sampled results cannot be compared because of randomness), so we
+    instead compare the probabilities renormalized by top-k and then by
+    top-p with the FlashInfer implementation.
+    '''
+
+    if not FLASHINFER_ENABLED:
+        pytest.skip(
+            "FlashInfer not installed or not available on this platform.")
+
+    with torch.device(DEVICE):
+        generator = Generator(device=DEVICE).manual_seed(42)
+
+        # Generate random logits
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+        # Generate various top-k and top-p values; p is in [0.5, 1.0]
+        k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator)
+        p_values = torch.rand((BATCH_SIZE,), generator=generator) * 0.5 + 0.5
+
+        # Sometimes disable top-k (k=vocab_size)
+        k_values.masked_fill_(
+            torch.randint(0, 2, (BATCH_SIZE,), generator=generator,
+                          dtype=torch.bool),
+            VOCAB_SIZE)
+
+        # Sometimes disable top-p (p=1.0)
+        p_values.masked_fill_(
+            torch.randint(0, 2, (BATCH_SIZE,), generator=generator,
+                          dtype=torch.bool),
+            1.0)
+
+        python_logits = apply_top_k_top_p(
+            logits=logits.clone(),
+            k=k_values,
+            p=p_values,
+        )
+        python_probs = torch.softmax(python_logits, dim=-1)
+
+        # FlashInfer only exposes renorm interfaces for probabilities,
+        # so convert the logits first
+        flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
+        flashinfer_probs = top_k_renorm_probs(
+            probs=flashinfer_probs,
+            top_k=k_values,
+        )
+        flashinfer_probs = top_p_renorm_probs(
+            probs=flashinfer_probs,
+            top_p=p_values,
+        )
+
+        # Compare the results
+        assert torch.allclose(python_probs, flashinfer_probs, atol=1e-5), \
+            "FlashInfer and Python sampling implementations do not match!"
