Skip to content

Commit af552bb

Browse files
committed
update return type
Signed-off-by: NickLucche <[email protected]>
1 parent ac23eb7 commit af552bb

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

vllm/v1/worker/tpu_model_runner.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -787,7 +787,8 @@ def execute_model(
787787
logits = self.structured_decode(require_struct_decoding,
788788
grammar_bitmask_padded, logits,
789789
arange)
790-
selected_token_ids, logprobs = self.sample_from_logits(logits, tpu_sampling_metadata)
790+
selected_token_ids, logprobs = self.sample_from_logits(
791+
logits, tpu_sampling_metadata)
791792
# Remove padding on cpu and keep dynamic op outside of xla graph.
792793
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
793794
logprobs_lists = logprobs.tolists() \
@@ -1254,7 +1255,8 @@ def compute_logits(self,
12541255
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
12551256
def sample_from_logits(
12561257
self, logits: torch.Tensor,
1257-
sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
1258+
sampling_metadata: TPUSupportedSamplingMetadata) -> \
1259+
tuple[torch.Tensor, Optional[LogprobsTensors]]:
12581260
"""
12591261
Sample with xla-friendly function. This function is to be traced
12601262
separately from `forward` for lighter compilation overhead.

0 commit comments

Comments (0)