diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py
index 9235afce1..079ab38c3 100644
--- a/flashinfer/sampling.py
+++ b/flashinfer/sampling.py
@@ -1151,7 +1151,7 @@ def chain_speculative_sampling(
         Shape: ``(batch_size, num_speculate_tokens, vocab_size)``
     draft_token_ids: torch.Tensor
         The draft model's generated token indices.
-        Shape: ``(batch_size, num_specutate_tokens)``
+        Shape: ``(batch_size, num_speculate_tokens)``
     target_probs: torch.Tensor
         Expected to be uniformly distributed in ``[0, 1)``.
     target_probs: torch.Tensor
@@ -1183,7 +1183,7 @@ def chain_speculative_sampling(
         Compared to input :attr:`draft_token_ids`, the output tensor has an additional
         token index at the end for the final token, if all previous tokens are accepted,
         another "bonus" token will be sampled from the target model's probability.
-        Shape: (batch_size, num_specutate_tokens + 1)
+        Shape: (batch_size, num_speculate_tokens + 1)
     output_accepted_token_num: torch.Tensor
         The number of tokens that can be accepted if each token is considered
         independently for each request. This metric does not consider the fact
         that rejection sampling will stop at the first token that does not
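
For reviewers, here is a minimal pure-PyTorch sketch of the rejection-sampling contract these docstrings describe: draft token ``i`` is accepted when ``u_i < target_probs[i, tok] / draft_probs[i, tok]``, emission stops at the first rejection (remaining outputs padded with ``-1``), and a "bonus" token is drawn from the target distribution when every draft token is accepted. This is an illustration of the documented semantics only, not the flashinfer kernel; the function name, the explicit Python loops, and the exact layout of the ``uniform_samples`` argument are assumptions made for clarity.

```python
import torch


def chain_speculative_sampling_ref(
    draft_probs: torch.Tensor,      # (batch_size, num_speculate_tokens, vocab_size)
    draft_token_ids: torch.Tensor,  # (batch_size, num_speculate_tokens)
    uniform_samples: torch.Tensor,  # (batch_size, num_speculate_tokens + 1), in [0, 1)
    target_probs: torch.Tensor,     # (batch_size, num_speculate_tokens + 1, vocab_size)
):
    # Hypothetical reference implementation of the semantics in the docstring,
    # not flashinfer's API or kernel.
    batch_size, num_spec, _ = draft_probs.shape
    # Rejected positions stay padded with -1, matching the docstring above.
    output_token_ids = torch.full((batch_size, num_spec + 1), -1, dtype=torch.long)
    output_accepted_token_num = torch.zeros(batch_size, dtype=torch.long)

    for b in range(batch_size):
        emitting = True  # no rejection seen yet for this request
        for i in range(num_spec):
            tok = draft_token_ids[b, i]
            p, q = target_probs[b, i, tok], draft_probs[b, i, tok]
            accept = uniform_samples[b, i] < p / q
            if accept:
                # Counted independently per position, which is what the
                # docstring's output_accepted_token_num measures.
                output_accepted_token_num[b] += 1
            if emitting:
                if accept:
                    output_token_ids[b, i] = tok
                else:
                    # On rejection, resample from the renormalized residual
                    # max(target - draft, 0) and stop emitting further tokens.
                    residual = (target_probs[b, i] - draft_probs[b, i]).clamp(min=0)
                    output_token_ids[b, i] = torch.multinomial(
                        residual / residual.sum(), 1
                    ).item()
                    emitting = False
        if emitting:
            # All draft tokens accepted: draw the extra "bonus" token from the
            # target model's final distribution, giving num_spec + 1 outputs.
            output_token_ids[b, num_spec] = torch.multinomial(
                target_probs[b, num_spec], 1
            ).item()
    return output_token_ids, output_accepted_token_num
```

The residual-resampling step is what makes the chained scheme exact: accepting with probability ``min(1, p/q)`` and otherwise sampling from the renormalized ``max(p - q, 0)`` leaves the target distribution unchanged, so the verified sequence is distributed as if the target model had sampled it directly.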