Commit 68a0378
[API] Fix top_k_top_p_sampling_from_logits param typo (#875)
Resolves #873
Parent: 78dde79 · Commit: 68a0378

File tree

1 file changed: +3 −3 lines

flashinfer/sampling.py (+3 −3)
@@ -703,7 +703,7 @@ def min_p_sampling_from_probs(


 def top_k_top_p_sampling_from_logits(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     uniform_samples: torch.Tensor,
     top_k: Union[torch.Tensor, int],
     top_p: Union[torch.Tensor, float],
@@ -798,13 +798,13 @@ def top_k_top_p_sampling_from_logits(
     top_p_sampling_from_probs
     """
     if filter_apply_order == "top_k_first":
-        masked_logits = top_k_mask_logits(probs, top_k)
+        masked_logits = top_k_mask_logits(logits, top_k)
         probs = torch.softmax(masked_logits, dim=-1)
         return top_p_sampling_from_probs(
             probs, uniform_samples, top_p, deterministic, check_nan=check_nan
         )
     elif filter_apply_order == "joint":
-        probs = torch.softmax(probs, dim=-1)
+        probs = torch.softmax(logits, dim=-1)
         if check_nan:
             if torch.any(torch.isnan(probs)):
                 raise ValueError("Input probs contains NaN.")
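
For context, here is a minimal plain-PyTorch sketch of what the corrected "top_k_first" path computes once the parameter is properly named logits: mask everything outside the top-k logits, only then apply softmax, and finish with top-p (nucleus) sampling. The helper name top_k_then_top_p_sample and the explicit top-p steps below are illustrative stand-ins written for this note, not FlashInfer's fused top_k_mask_logits / top_p_sampling_from_probs kernels.

import torch

def top_k_then_top_p_sample(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # Top-k masking operates on raw logits: everything outside the k largest
    # entries is set to -inf so it vanishes after the softmax.
    kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    masked_logits = logits.masked_fill(logits < kth_value, float("-inf"))

    # Only now are the logits converted to probabilities; this is why the
    # parameter is named `logits`, not `probs`.
    probs = torch.softmax(masked_logits, dim=-1)

    # Top-p (nucleus) filtering: keep the smallest prefix of tokens (sorted by
    # probability) whose cumulative mass reaches top_p, renormalize, sample.
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    keep = (cumulative - sorted_probs) < top_p  # the top token is always kept
    sorted_probs = sorted_probs * keep
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, choice).squeeze(-1)

# Example usage: batch of 2 sequences over a vocabulary of 8 tokens.
samples = top_k_then_top_p_sample(torch.randn(2, 8), top_k=5, top_p=0.9)

Passing probabilities instead of logits into this path would be wrong twice over: top-k masking with -inf only makes sense before a softmax, and applying softmax to already-normalized probabilities flattens the distribution. That is the confusion the renamed parameter avoids.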

0 commit comments