Skip to content

Commit 4049723

Browse files
WoosukKwon authored and shreyankg committed
[V1] Ensure using int64 for sampled token ids (vllm-project#15065)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 3a710a7 commit 4049723

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

vllm/v1/sample/sampler.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def forward(
4747
logits = self.apply_penalties(logits, sampling_metadata)
4848
# Sample the next token.
4949
sampled = self.sample(logits, sampling_metadata)
50+
# Convert sampled token ids to int64 (long) type to ensure compatibility
51+
# with subsequent operations that may use these values as indices.
52+
# This conversion is necessary because FlashInfer sampling operations
53+
# return int32 (while PyTorch argmax and topk return int64).
54+
sampled = sampled.long()
5055

5156
# Gather the logprobs of the topk and sampled token (if requested).
5257
# Get logprobs and rank tensors (if requested)
@@ -139,19 +144,21 @@ def gather_logprobs(
139144
or sampled tokens (if sampled
140145
logprobs); 1D token ID tensor
141146
with (num tokens) elements
147+
Must be int64.
142148
143149
Returns:
144150
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
145151
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
146152
Sampled token rank tensor, (num tokens)
147153
"""
154+
assert token_ids.dtype == torch.int64
148155
# Find the topK values.
149156
topk_logprobs, topk_indices = torch.topk(logprobs,
150157
num_logprobs,
151158
dim=-1)
152159

153160
# Get with the logprob of the prompt or sampled token.
154-
token_ids = token_ids.unsqueeze(-1).to(torch.long)
161+
token_ids = token_ids.unsqueeze(-1)
155162
token_logprobs = logprobs.gather(-1, token_ids)
156163

157164
# Compute the ranks of the actual token.

0 commit comments

Comments (0)