@@ -89,18 +89,18 @@ def forward_cuda(
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """More optimized implementation for top-k and top-p sampling."""
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
         if k is None and p is None:
             # We prefer `random_sample` over `flashinfer_sample` when sorting is
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
+            probs = logits.softmax(dim=-1, dtype=torch.float32)
             return random_sample(probs, generators)
         if generators:
             logger.warning("FlashInfer 0.2.3+ does not support "
                            "per-request generators. Falling back to "
                            "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
-        return flashinfer_sample(probs, k, p, generators)
+        return flashinfer_sample(logits, k, p, generators)

     def forward_tpu(
         self,
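This hunk defers the softmax into the only branch that needs probabilities: per the rest of the diff, the FlashInfer path now consumes raw logits, so the unconditional full-vocabulary softmax is skipped when FlashInfer handles sampling. As for the comment that `random_sample` needs no CPU-GPU synchronization, a common sync-free categorical sampler uses the exponential-race (Gumbel-max) trick; the sketch below shows the general technique and is not necessarily vLLM's exact `random_sample`:

```python
import torch

def sync_free_sample(probs: torch.Tensor) -> torch.Tensor:
    """Sample one token id per row of `probs` without any host round-trip."""
    # Draw i.i.d. Exp(1) noise; argmax of probs / q is distributed as
    # Categorical(probs) (equivalent to Gumbel-max on log-probs), so the
    # whole sampler stays on the GPU with no CPU-GPU synchronization.
    q = torch.empty_like(probs).exponential_()
    return probs.div(q).argmax(dim=-1)
```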
@@ -254,7 +254,7 @@ def random_sample(


 def flashinfer_sample(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     k: Optional[torch.Tensor],
     p: Optional[torch.Tensor],
     generators: dict[int, torch.Generator],
@@ -264,7 +264,7 @@ def flashinfer_sample(
     Statistically, this function is equivalent to the `random_sample` function.
     However, this function is faster because it avoids sorting the logits tensor
     via rejection sampling.
-
+
     NOTE: The outputs of this function do not necessarily match the outputs of
     the `random_sample` function. It only guarantees that the outputs are
     statistically equivalent.
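For contrast with the rejection-sampling kernels described in this docstring, a sorting-based PyTorch reference for joint top-k/top-p sampling could look like the sketch below. It is an illustrative reimplementation assuming per-request `k`/`p` tensors, not vLLM's `forward_native`; its point is the full-vocabulary sort that FlashInfer avoids:

```python
import torch

def sorted_top_k_top_p_sample(logits: torch.Tensor, k: torch.Tensor,
                              p: torch.Tensor) -> torch.Tensor:
    # Sort once per step -- the expensive step the FlashInfer kernels avoid.
    logits_sorted, idx = logits.sort(dim=-1, descending=True)
    # Mask everything beyond each request's top-k.
    ranks = torch.arange(logits.shape[-1], device=logits.device)
    logits_sorted = logits_sorted.masked_fill(
        ranks >= k.unsqueeze(-1), float("-inf"))
    # Mask the tail once cumulative probability exceeds each request's top-p
    # (the token that crosses the threshold is kept).
    probs_sorted = logits_sorted.softmax(dim=-1)
    cumprobs = probs_sorted.cumsum(dim=-1)
    logits_sorted = logits_sorted.masked_fill(
        cumprobs - probs_sorted > p.unsqueeze(-1), float("-inf"))
    # Sample among the survivors and map back to original vocab ids.
    sampled = torch.multinomial(logits_sorted.softmax(dim=-1), num_samples=1)
    return idx.gather(-1, sampled).view(-1)
```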
@@ -274,18 +274,19 @@ def flashinfer_sample(
     the synchronization overhead.
     """
     assert not (k is None and p is None)
-
     if k is None:
         # Top-p only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
             probs, p, deterministic=True)
     elif p is None:
         # Top-k only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
             probs, k, deterministic=True)
     else:
         # Both top-k and top-p.
-        next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
-            probs, k, p, deterministic=True))
+        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
+            logits, k, p, deterministic=True)

     return next_token_ids.view(-1)
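After this change, call sites pass raw logits rather than pre-computed probabilities. A hypothetical usage of the new signature, with shapes and values invented for illustration:

```python
import torch

logits = torch.randn(4, 32_000, device="cuda")          # [batch, vocab]
k = torch.tensor([50, 50, 100, 20], device="cuda")      # per-request top-k
p = torch.tensor([0.9, 0.95, 1.0, 0.8], device="cuda")  # per-request top-p

# Old: flashinfer_sample(logits.softmax(dim=-1, dtype=torch.float32), k, p, {})
# New: the softmax happens inside, and only on the branches that need it.
next_token_ids = flashinfer_sample(logits, k, p, generators={})
```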