Skip to content

Commit d14da5d

Browse files
committed
update return type
Signed-off-by: NickLucche <[email protected]>
1 parent 2e0f7cc commit d14da5d

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

vllm/v1/worker/tpu_model_runner.py

Lines changed: 4 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -786,7 +786,8 @@ def execute_model(
786786
logits = self.structured_decode(require_struct_decoding,
787787
grammar_bitmask_padded, logits,
788788
arange)
789-
selected_token_ids, logprobs = self.sample_from_logits(logits, tpu_sampling_metadata)
789+
selected_token_ids, logprobs = self.sample_from_logits(
790+
logits, tpu_sampling_metadata)
790791
# Remove padding on cpu and keep dynamic op outside of xla graph.
791792
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
792793
logprobs_lists = logprobs.tolists() \
@@ -1253,7 +1254,8 @@ def compute_logits(self,
12531254
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
12541255
def sample_from_logits(
12551256
self, logits: torch.Tensor,
1256-
sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
1257+
sampling_metadata: TPUSupportedSamplingMetadata) -> \
1258+
tuple[torch.Tensor, Optional[LogprobsTensors]]:
12571259
"""
12581260
Sample with xla-friendly function. This function is to be traced
12591261
separately from `forward` for lighter compilation overhead.

0 commit comments

Comments
 (0)