Skip to content

Commit 0d71400

Browse files
committed
rebase sync
Signed-off-by: NickLucche <[email protected]>
1 parent ca564f5 commit 0d71400

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

vllm/v1/worker/tpu_model_runner.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -688,20 +688,19 @@ def execute_model(
688688
hidden_states = self.model(
689689
input_ids=input_ids,
690690
positions=self.position_ids,
691-
kv_caches=self.kv_caches,
692691
inputs_embeds=inputs_embeds,
693692
)
694693
# Temporary debug pathway for sampling.
695694
if self._disable_sampler:
696-
selected_token_ids = self.model.compute_logits_no_sampler(
695+
selected_token_ids = self.compute_logits_no_sampler(
697696
hidden_states, logits_indices)
698697
else:
699698
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
700699
# are copied to device in chunks of pre-compiled padded shape to
701700
# avoid recompilations.
702701
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
703702
from_input_batch(self.input_batch, logits_indices)
704-
selected_token_ids = self.model.sample_from_hidden(
703+
selected_token_ids = self.sample_from_hidden(
705704
hidden_states, tpu_sampling_metadata)
706705
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
707706

@@ -902,13 +901,11 @@ def capture_model(self) -> None:
902901
xm.mark_step()
903902
if self._disable_sampler:
904903
# Compile no sampler path for debugging performance
905-
out = self.model.compute_logits_no_sampler(
906-
dummy_hidden, indices)
904+
out = self.compute_logits_no_sampler(dummy_hidden, indices)
907905
else:
908906
sampling_meta = TPUSupportedSamplingMetadata.\
909907
from_input_batch(self.input_batch, indices)
910-
out = self.model.sample_from_hidden(
911-
dummy_hidden, sampling_meta)
908+
out = self.sample_from_hidden(dummy_hidden, sampling_meta)
912909
out = out.cpu()
913910
# Requests can't be more than tokens. But do compile for the
914911
# next bigger value in case num_tokens uses bucketed padding.
@@ -980,9 +977,9 @@ def sample_from_hidden(
980977
sampling_metadata: TPUSupportedSamplingMetadata,
981978
) -> torch.Tensor:
982979
"""
983-
Sample with xla-friendly function. This function is to be traced
984-
separately for lighter compilation overhead.
985-
"""
980+
Sample with xla-friendly function. This function is to be traced
981+
separately for lighter compilation overhead.
982+
"""
986983
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
987984
sample_hidden_states = \
988985
hidden_states[sampling_metadata.indices_do_sample]
@@ -1004,7 +1001,6 @@ def sample(
10041001
sample(logits, sampling_metadata).sampled_token_ids)
10051002
return out_tokens
10061003

1007-
10081004
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
10091005
def compute_logits_no_sampler(
10101006
self, hidden_states: torch.Tensor,

0 commit comments

Comments (0)