 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
 
 
+@pytest.fixture(autouse=True)
+def reset_default_device():
+    """
+    Explicitly setting the default device can affect subsequent tests.
+    This fixture restores the original default device after each test
+    to avoid that problem.
+    """
+    original_device = torch.get_default_device()
+    yield
+    torch.set_default_device(original_device)
+
+
 def test_topk_impl_equivalance():
 
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(33)
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(33)
 
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
 
-        # Random top-k values between 1 and 9.
-        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
+    # Random top-k values between 1 and 9.
+    k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
 
-        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
-        k.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=bool), VOCAB_SIZE)
+    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+    k.masked_fill_(
+        torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
+        VOCAB_SIZE)
 
-        # Top-k only implementation
-        result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+    # Top-k only implementation
+    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
 
-        # Top-p + top-k
-        no_op_top_p = torch.tensor([1.0])
-        result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
+    # Top-p + top-k
+    no_op_top_p = torch.tensor([1.0])
+    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
 
-        assert torch.allclose(result1, result2)
+    assert torch.allclose(result1, result2)
 
 
 def test_flashinfer_sampler():
@@ -58,50 +67,49 @@ def test_flashinfer_sampler():
         pytest.skip(
             "FlashInfer not installed or not available on this platform.")
 
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(42)
-
-        # Generate random logits
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
-
-        # Generate various top-k and top-p values
-        k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
-        p_values = torch.rand(
-            (BATCH_SIZE, ),
-            generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
-
-        # Sometimes disable top-k (k=vocab_size)
-        k_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), VOCAB_SIZE)
-
-        # Sometimes disable top-p (p=1.0)
-        p_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), 1.0)
-
-        python_logits = apply_top_k_top_p(
-            logits=logits.clone(),
-            k=k_values,
-            p=p_values,
-        )
-        python_probs = torch.softmax(python_logits, dim=-1)
-
-        # FlashInfer only exposes renorm interfaces for probs, so convert first
-        flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
-        flashinfer_probs = top_k_renorm_probs(
-            probs=flashinfer_probs,
-            top_k=k_values,
-        )
-        flashinfer_probs = top_p_renorm_probs(
-            probs=flashinfer_probs,
-            top_p=p_values,
-        )
-
-        # Compare the results
-        assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
-            "FlashInfer and Python sampling implementations do not match!"
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(42)
+
+    # Generate random logits
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+    # Generate various top-k and top-p values
+    k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
+    p_values = torch.rand(
+        (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
+
+    # Sometimes disable top-k (k=vocab_size)
+    k_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), VOCAB_SIZE)
+
+    # Sometimes disable top-p (p=1.0)
+    p_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), 1.0)
+
+    python_logits = apply_top_k_top_p(
+        logits=logits.clone(),
+        k=k_values,
+        p=p_values,
+    )
+    python_probs = torch.softmax(python_logits, dim=-1)
+
+    # FlashInfer only exposes renorm interfaces for probs, so convert first
+    flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
+    flashinfer_probs = top_k_renorm_probs(
+        probs=flashinfer_probs,
+        top_k=k_values,
+    )
+    flashinfer_probs = top_p_renorm_probs(
+        probs=flashinfer_probs,
+        top_p=p_values,
+    )
+
+    # Compare the results
+    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
+        "FlashInfer and Python sampling implementations do not match!"