Skip to content

Commit 61ec7f8

Browse files
committed
precompiling no sampling graph
Signed-off-by: NickLucche <[email protected]>
1 parent 8bfcdc3 commit 61ec7f8

File tree

2 files changed

+17
-17
lines changed

2 files changed

+17
-17
lines changed

vllm/envs.py

-3
Original file line numberDiff line numberDiff line change
@@ -690,15 +690,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
690690
lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
691691
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
692692

693-
<<<<<<< HEAD
694693
# Allow use of DeepGemm kernels for fused moe ops.
695694
"VLLM_USE_DEEP_GEMM":
696695
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
697-
=======
698696
# Disable sampler path for debugging performance.
699697
"VLLM_TPU_DISABLE_SAMPLER_DEBUG":
700698
lambda: os.environ.get("VLLM_TPU_DISABLE_SAMPLER_DEBUG", "0") == "1",
701-
>>>>>>> fdc71ec1a (updates)
702699
}
703700

704701
# end-env-vars-definition

vllm/v1/worker/tpu_model_runner.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(
9595
# InputBatch needs to work with sampling tensors greater than padding
9696
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
9797
self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
98+
self._disable_sampler = envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG
9899

99100
# Model-related.
100101
self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -639,7 +640,7 @@ def execute_model(
639640
num_reqs = self.input_batch.num_reqs
640641

641642
# Temporary debug pathway.
642-
if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
643+
if self._disable_sampler:
643644
with set_forward_context(attn_metadata, self.vllm_config):
644645
hidden_states = self.model(
645646
input_ids=input_ids,
@@ -648,7 +649,7 @@ def execute_model(
648649
inputs_embeds=inputs_embeds,
649650
)
650651
selected_token_ids = self.model.compute_logits_no_sampler(
651-
hidden_states, logits_indices, None)
652+
hidden_states, logits_indices)
652653
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
653654
else:
654655
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
@@ -857,18 +858,23 @@ def capture_model(self) -> None:
857858
dtype=self._hidden_states_dtype)
858859
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
859860
while True:
861+
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
862+
num_reqs_to_sample)
860863
indices = torch.zeros(
861864
num_reqs_to_sample,
862865
dtype=torch.int32,
863866
device=device,
864867
)
865868
xm.mark_step()
866-
sampling_meta = TPUSupportedSamplingMetadata.\
867-
from_input_batch(self.input_batch, indices)
868-
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
869-
num_reqs_to_sample)
870-
out = self.model.sample_from_hidden(dummy_hidden,
871-
sampling_meta)
869+
if self._disable_sampler:
870+
# Compile no sampler path for debugging performance
871+
out = self.model.compute_logits_no_sampler(
872+
dummy_hidden, indices)
873+
else:
874+
sampling_meta = TPUSupportedSamplingMetadata.\
875+
from_input_batch(self.input_batch, indices)
876+
out = self.model.sample_from_hidden(
877+
dummy_hidden, sampling_meta)
872878
out = out.cpu()
873879
# Requests can't be more than tokens. But do compile for the
874880
# next bigger value in case num_tokens uses bucketed padding.
@@ -991,13 +997,10 @@ def compute_logits(self,
991997

992998
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
993999
def compute_logits_no_sampler(
994-
self,
995-
hidden_states: torch.Tensor,
996-
logits_indices: torch.Tensor,
997-
sampling_metadata,
998-
) -> Optional[torch.Tensor]:
1000+
self, hidden_states: torch.Tensor,
1001+
logits_indices: torch.Tensor) -> Optional[torch.Tensor]:
9991002
hidden_states = hidden_states[logits_indices]
1000-
logits = self.model.compute_logits(hidden_states, sampling_metadata)
1003+
logits = self.model.compute_logits(hidden_states, None)
10011004
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
10021005
return selected_token_ids
10031006

0 commit comments

Comments (0)