@@ -47,6 +47,11 @@ def forward(
         logits = self.apply_penalties(logits, sampling_metadata)
         # Sample the next token.
         sampled = self.sample(logits, sampling_metadata)
+        # Convert sampled token ids to int64 (long) type to ensure compatibility
+        # with subsequent operations that may use these values as indices.
+        # This conversion is necessary because FlashInfer sampling operations
+        # return int32 (while PyTorch argmax and topk return int64).
+        sampled = sampled.long()
 
         # Gather the logprobs of the topk and sampled token (if requested).
         # Get logprobs and rank tensors (if requested)
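
Why the explicit cast matters: PyTorch index-based ops such as `torch.gather` only accept int64 index tensors, so the int32 ids produced by FlashInfer would fail the first time they are used as indices. A minimal standalone sketch of the failure mode and the fix (illustrative shapes and names; not part of this diff):

```python
import torch

logprobs = torch.log_softmax(torch.randn(4, 32000), dim=-1)  # (num tokens) x (vocab)
sampled_i32 = torch.randint(0, 32000, (4,), dtype=torch.int32)  # int32, as FlashInfer returns

try:
    # gather() rejects non-int64 index tensors with a RuntimeError.
    logprobs.gather(-1, sampled_i32.unsqueeze(-1))
except RuntimeError as err:
    print(err)

# Converting once, right after sampling, keeps every downstream consumer safe.
sampled = sampled_i32.long()  # int64
token_logprobs = logprobs.gather(-1, sampled.unsqueeze(-1))
```
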
@@ -139,19 +144,21 @@ def gather_logprobs(
                      or sampled tokens (if sampled
                      logprobs); 1D token ID tensor
                      with (num tokens) elements
+                     Must be int64.
 
         Returns:
           Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
           Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
           Sampled token rank tensor, (num tokens)
         """
+        assert token_ids.dtype == torch.int64
         # Find the topK values.
         topk_logprobs, topk_indices = torch.topk(logprobs,
                                                  num_logprobs,
                                                  dim=-1)
 
         # Get with the logprob of the prompt or sampled token.
-        token_ids = token_ids.unsqueeze(-1).to(torch.long)
+        token_ids = token_ids.unsqueeze(-1)
         token_logprobs = logprobs.gather(-1, token_ids)
 
         # Compute the ranks of the actual token.
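
For orientation, a self-contained sketch of the gather-and-rank pattern that `gather_logprobs` implements. The rank formula here is one common formulation and the function name is illustrative; this is not the exact vLLM code:

```python
import torch

def gather_logprobs_sketch(logprobs: torch.Tensor, num_logprobs: int,
                           token_ids: torch.Tensor):
    # token_ids must be int64, matching the assert added in this diff.
    assert token_ids.dtype == torch.int64
    # Top-k logprobs and their vocab indices per token.
    topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
    # Logprob of each prompt/sampled token (gather requires int64 indices).
    token_logprobs = logprobs.gather(-1, token_ids.unsqueeze(-1))
    # Rank of the token: count of vocab entries with logprob >= its own.
    token_ranks = (logprobs >= token_logprobs).sum(-1)
    return topk_logprobs, topk_indices, token_logprobs, token_ranks

logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)
topk_lp, topk_idx, tok_lp, ranks = gather_logprobs_sketch(
    logprobs, 3, torch.tensor([1, 5]))
```
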