
Commit d3f18f3

yaochengji authored and DamonFool committed
[Hardware][TPU] Fix the recompiling issue in logits processor after warmup (vllm-project#14510)
Signed-off-by: Chengji Yao <[email protected]>
1 parent d07b3a1 commit d3f18f3

2 files changed: +41 −10 lines changed

examples/offline_inference/tpu.py (3 additions & 1 deletion)

@@ -21,7 +21,9 @@
 
 # Set `enforce_eager=True` to avoid ahead-of-time compilation.
 # In real workloads, `enforace_eager` should be `False`.
-llm = LLM(model="google/gemma-2b", enforce_eager=True)
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
+          max_num_batched_tokens=64,
+          max_num_seqs=4)
 outputs = llm.generate(prompts, sampling_params)
 for output, answer in zip(outputs, answers):
     prompt = output.prompt
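
Note: the following is a minimal sketch of driving the updated example end to end. The prompt and sampling settings are placeholders, not taken from the example file; only the LLM arguments mirror the diff above.

# Hypothetical driver; prompt and sampling values are illustrative only.
from vllm import LLM, SamplingParams

prompts = ["The capital of France is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

# Mirrors the example change above: a small model with tight batching limits,
# so the TPU warmup only has a handful of bucket shapes to pre-compile.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          max_num_batched_tokens=64,
          max_num_seqs=4)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)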

vllm/v1/worker/tpu_model_runner.py (38 additions & 9 deletions)

@@ -401,6 +401,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         self.query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens_per_req,
                   out=self.query_start_loc_np[1:num_reqs + 1])
+        self.query_start_loc_np[num_reqs + 1:] = 1
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
@@ -441,7 +442,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
+            num_reqs, self.max_num_reqs)
+        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
+        logits_indices = logits_indices.to(self.device)
         return attn_metadata, logits_indices
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -551,7 +555,6 @@ def execute_model(
 
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
 
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
@@ -579,12 +582,10 @@ def execute_model(
             kv_caches=self.kv_caches,
             inputs_embeds=inputs_embeds,
         )
-        hidden_states = hidden_states[:total_num_scheduled_tokens]
         num_reqs = self.input_batch.num_reqs
-        logits_indices = logits_indices[:num_reqs]
-        hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(hidden_states, None)
-        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        selected_token_ids = self.model.compute_logits(hidden_states,
+                                                       logits_indices, None)
+        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
         # Then, let's update the cache state.
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
@@ -726,12 +727,31 @@ def _dummy_run(
 
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
-            self.model(
+            hidden_states = self.model(
                 input_ids=input_ids,
                 positions=position_ids,
                 kv_caches=kv_caches,
                 inputs_embeds=inputs_embeds,
             )
+            num_reqs = _get_padded_num_reqs_with_upper_limit(
+                64, self.max_num_reqs)
+            # NOTE(chengjiyao): In total, the compute_logits function utilizes a
+            # compilation cache size of token_bucket_num multiplied by
+            # req_bucket_num. This is acceptable, given the graph's relatively
+            # small size.
+            while True:
+                logits_indices = torch.zeros(
+                    num_reqs,
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                torch._dynamo.mark_dynamic(hidden_states, 0)
+                torch._dynamo.mark_dynamic(logits_indices, 0)
+                self.model.compute_logits(hidden_states, logits_indices, None)
+                if num_reqs >= self.max_num_reqs:
+                    break
+                num_reqs = _get_padded_num_reqs_with_upper_limit(
+                    num_reqs + 1, self.max_num_reqs)
 
     def capture_model(self) -> None:
         """Compile the model."""
@@ -823,13 +843,17 @@ def forward(
 
         return hidden_states
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
+        logits_indices: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
+        hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
-        return logits
+        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        return selected_token_ids
 
     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)
@@ -846,3 +870,8 @@ def _get_padded_token_len(x: int) -> int:
     if x <= 16:
         return 16
     return 1 << (x - 1).bit_length()
+
+
+def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
+    res = 64 if x <= 64 else 1 << (x - 1).bit_length()
+    return min(res, upper_limit)
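
To see why the changes above stop the post-warmup recompilation, here is a toy, vLLM-independent illustration of the shape-stability trick: query_start_loc is padded (the new line in _prepare_inputs), logits_indices is sliced to a fixed padded length, and the gather plus argmax always run on the same shapes, with the real rows recovered by the final [:num_reqs] slice. All sizes below are made up for the example.

import torch

# Toy sizes: 3 real requests padded to a request bucket of 8,
# 16 padded tokens, hidden size 4, vocab size 11.
num_reqs, padded_num_reqs, padded_num_tokens, hidden, vocab = 3, 8, 16, 4, 11
num_scheduled = torch.tensor([5, 2, 4])

query_start_loc = torch.zeros(padded_num_reqs + 1, dtype=torch.long)
query_start_loc[1:num_reqs + 1] = torch.cumsum(num_scheduled, 0)
query_start_loc[num_reqs + 1:] = 1               # mirrors the new padding line

logits_indices = query_start_loc[1:padded_num_reqs + 1] - 1   # fixed length 8
hidden_states = torch.randn(padded_num_tokens, hidden)

gathered = hidden_states[logits_indices]         # always (8, hidden): shape never changes
logits = gathered @ torch.randn(hidden, vocab)   # stand-in for the LM head
selected = torch.argmax(logits, dim=-1, keepdim=True)[:num_reqs]
print(selected.shape)                            # torch.Size([3, 1]) for the real requests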
