Commit 49d5fcb

precompiling no sampling graph
Signed-off-by: NickLucche <[email protected]>
1 parent a1f45bf commit 49d5fcb

2 files changed: +24 -10 lines changed


vllm/envs.py (-3)

@@ -695,15 +695,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
     if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
 
-<<<<<<< HEAD
     # Allow use of DeepGemm kernels for fused moe ops.
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
-=======
     # Disable sampler path for debugging performance.
     "VLLM_TPU_DISABLE_SAMPLER_DEBUG":
     lambda: os.environ.get("VLLM_TPU_DISABLE_SAMPLER_DEBUG", "0") == "1",
->>>>>>> fdc71ec1a (updates)
 }
 
 # end-env-vars-definition
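
The VLLM_TPU_DISABLE_SAMPLER_DEBUG entry kept here follows the same lazy-lambda pattern as the rest of environment_variables: the string value is parsed only when the flag is read. A minimal, self-contained sketch of that pattern (the _ENV_FLAGS dict and read_flag helper are illustrative, not part of vLLM, which exposes these values through attribute access on the envs module):

import os

# Lazy env-flag pattern, as in vllm/envs.py: map name -> parser lambda.
_ENV_FLAGS = {
    # "1" enables the debug path that bypasses the sampler entirely.
    "VLLM_TPU_DISABLE_SAMPLER_DEBUG":
    lambda: os.environ.get("VLLM_TPU_DISABLE_SAMPLER_DEBUG", "0") == "1",
}

def read_flag(name: str) -> bool:
    # Hypothetical helper: evaluate the lambda on access.
    return _ENV_FLAGS[name]()

os.environ["VLLM_TPU_DISABLE_SAMPLER_DEBUG"] = "1"
print(read_flag("VLLM_TPU_DISABLE_SAMPLER_DEBUG"))  # True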

vllm/v1/worker/tpu_model_runner.py (+24 -7)
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(
         # InputBatch needs to work with sampling tensors greater than padding
         # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
+        self._disable_sampler = envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -684,7 +685,7 @@ def execute_model(
         num_reqs = self.input_batch.num_reqs
 
         # Temporary debug pathway.
-        if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
+        if self._disable_sampler:
             with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = self.model(
                     input_ids=input_ids,
@@ -693,7 +694,7 @@ def execute_model(
                     inputs_embeds=inputs_embeds,
                 )
             selected_token_ids = self.model.compute_logits_no_sampler(
-                hidden_states, logits_indices, None)
+                hidden_states, logits_indices)
             selected_token_ids = selected_token_ids.cpu()[:num_reqs]
         else:
             # NOTE (NickLucche) here we sync with TPU: sampling params tensors
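
The trailing .cpu()[:num_reqs] trim exists because batches on TPU are padded to fixed sizes to avoid dynamic shapes (see the comment in the __init__ hunk above); only the first num_reqs rows of the compiled graph's output correspond to real requests. A tiny illustration with made-up sizes:

import torch

padded_num_reqs, num_reqs = 16, 3   # illustrative padded vs. real batch sizes
selected_token_ids = torch.zeros(padded_num_reqs, 1, dtype=torch.int32)
# Move to host and drop the padded rows: one token id per real request remains.
real_token_ids = selected_token_ids.cpu()[:num_reqs]
print(real_token_ids.shape)  # torch.Size([3, 1])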
@@ -899,17 +900,23 @@ def capture_model(self) -> None:
             dtype=self._hidden_states_dtype)
         # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
         while True:
+            logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
+                        num_reqs_to_sample)
             indices = torch.zeros(
                 num_reqs_to_sample,
                 dtype=torch.int32,
                 device=device,
             )
             xm.mark_step()
-            sampling_meta = TPUSupportedSamplingMetadata.\
-                from_input_batch(self.input_batch, indices)
-            logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
-                        num_reqs_to_sample)
-            out = self.sample_from_hidden(dummy_hidden, sampling_meta)
+            if self._disable_sampler:
+                # Compile no sampler path for debugging performance
+                out = self.model.compute_logits_no_sampler(
+                    dummy_hidden, indices)
+            else:
+                sampling_meta = TPUSupportedSamplingMetadata.\
+                    from_input_batch(self.input_batch, indices)
+                out = self.model.sample_from_hidden(
+                    dummy_hidden, sampling_meta)
             out = out.cpu()
             # Requests can't be more than tokens. But do compile for the
             # next bigger value in case num_tokens uses bucketed padding.
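
Each pass through this loop traces and compiles one graph for a fixed (num_tokens, num_reqs_to_sample) shape, so neither the sampler path nor the new no-sampler path hits an unseen shape at serving time. A standalone sketch of the request-count schedule the loop walks, assuming the start-at-8-and-double pattern implied by the "Compile for [8, 16, .., 128,.., self.max_num_reqs]" comment (the helper name is made up):

MIN_NUM_SEQS = 8  # assumed lower bound, matching the comment above

def sampling_buckets(max_num_reqs: int):
    # Yield the padded request counts the capture loop compiles for:
    # 8, 16, 32, ... doubling until max_num_reqs is covered.
    n = MIN_NUM_SEQS
    while True:
        yield min(n, max_num_reqs)
        if n >= max_num_reqs:
            break
        n *= 2

print(list(sampling_buckets(100)))  # [8, 16, 32, 64, 100]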
@@ -1006,6 +1013,16 @@ def sample(
         return out_tokens
 
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def compute_logits_no_sampler(
+            self, hidden_states: torch.Tensor,
+            logits_indices: torch.Tensor) -> Optional[torch.Tensor]:
+        hidden_states = hidden_states[logits_indices]
+        logits = self.model.compute_logits(hidden_states, None)
+        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        return selected_token_ids
+
+
 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
 
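
Functionally, compute_logits_no_sampler is greedy decoding compiled as a single XLA graph: gather the hidden state at each request's sampling position, project it to vocabulary logits, and take the argmax. A minimal CPU-only sketch with a stand-in linear head (the real logits come from self.model.compute_logits; sizes are illustrative):

import torch

hidden_size, vocab_size = 16, 32
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)  # stand-in head

hidden_states = torch.randn(10, hidden_size)   # padded token positions
logits_indices = torch.tensor([3, 7])          # sampling position of each request

gathered = hidden_states[logits_indices]                         # (num_reqs, hidden_size)
logits = lm_head(gathered)                                       # (num_reqs, vocab_size)
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)  # (num_reqs, 1)
print(selected_token_ids.shape)  # torch.Size([2, 1])

As an aside, the _get_padded_number helper visible at the bottom of the hunk rounds n up to the nearest multiple, e.g. _get_padded_number(10, 8) == 16.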
