 import torch
 from torch import Generator
 
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+import pytest
+
+from vllm.platforms import current_platform
+from vllm.v1.sample.ops.topk_topp_sampler import (
+    is_flashinfer_available,
+    apply_top_k_top_p,
+)
+
+if is_flashinfer_available:
+    from flashinfer.sampling import (
+        top_k_renorm_probs,
+        top_p_renorm_probs,
+    )
 
 DEVICE = "cuda"
 
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
 
+FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
+
 
 def test_topk_impl_equivalance():
@@ -35,3 +48,62 @@ def test_topk_impl_equivalance():
     result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
 
     assert torch.allclose(result1, result2)
+
+
+def test_flashinfer_sampler():
+    '''
+    This test verifies that the FlashInfer top-k and top-p sampling
+    implementation produces the same results as the Python implementation.
+
+    NOTE: FlashInfer does not directly expose an interface for fused top-k and
+    top-p probability renormalization (it does provide fused sampling, but
+    sampling results cannot be compared due to randomness), so instead we
+    compare the probabilities renormalized sequentially by top-k and then
+    top-p between the two implementations.
+    '''
+
+    if not FLASHINFER_ENABLED:
+        pytest.skip(
+            "FlashInfer not installed or not available on this platform.")
+
+    with torch.device(DEVICE):
+        generator = Generator(device=DEVICE).manual_seed(42)
+
+        # Generate random logits
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+        # Generate various top-k and top-p values
+        k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator)
+        p_values = torch.rand(
+            (BATCH_SIZE,), generator=generator) * 0.5 + 0.5  # in [0.5, 1.0]
+
+        # Sometimes disable top-k (k=vocab_size)
+        k_values.masked_fill_(
+            torch.randint(0, 2, (BATCH_SIZE,), generator=generator,
+                          dtype=torch.bool),
+            VOCAB_SIZE)
+
+        # Sometimes disable top-p (p=1.0)
+        p_values.masked_fill_(
+            torch.randint(0, 2, (BATCH_SIZE,), generator=generator,
+                          dtype=torch.bool),
+            1.0)
+
+        python_logits = apply_top_k_top_p(
+            logits=logits.clone(),
+            k=k_values,
+            p=p_values,
+        )
+        python_probs = torch.softmax(python_logits, dim=-1)
+
+        # FlashInfer only exposes renorm interfaces for probs, so convert first
+        flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
+        flashinfer_probs = top_k_renorm_probs(
+            probs=flashinfer_probs,
+            top_k=k_values,
+        )
+        flashinfer_probs = top_p_renorm_probs(
+            probs=flashinfer_probs,
+            top_p=p_values,
+        )
+
+        # Compare the results
+        assert torch.allclose(python_probs, flashinfer_probs, atol=1e-5), \
+            "FlashInfer and Python sampling implementations do not match!"