
[Bugfix][TPU] Fix tpu model runner testcase failure #18810

Merged: 6 commits, merged May 30, 2025.

31 changes: 26 additions & 5 deletions tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -81,7 +81,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
                 mm_hashes=[],
                 mm_positions=[],
                 sampling_params=SamplingParams(),
-                block_ids=[0],
+                block_ids=[[0]],  # block_ids should be list[list[int]]
                 num_computed_tokens=0,
                 lora_request=None,
             ))
@@ -112,14 +112,35 @@ def _is_req_added(model_runner, req_id: str) -> bool:


 def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
+    """Check if the request state block IDs match the block table.
+    This function handles both legacy BlockTable and new MultiGroupBlockTable
+    structures for backward compatibility.
+    """
+
     req_index = model_runner.input_batch.req_id_to_index[req_id]
-    block_table = model_runner.input_batch.block_table
+    multi_group_block_table = model_runner.input_batch.block_table
     req_state = model_runner.requests[req_id]
-    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
+
+    # Access the first block table from MultiGroupBlockTable
+    # This is safe since we currently only use single KV cache groups
+    block_table = multi_group_block_table[0]
+
+    # req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
+    # Extract the first group's block IDs
+    if isinstance(req_state.block_ids[0], list):
+        # New format: list[list[int]] - extract first group
+        req_block_ids = req_state.block_ids[0]
+    else:
+        # Legacy format: list[int] - use directly
+        req_block_ids = req_state.block_ids
+
+    if block_table.num_blocks_per_row[req_index] != len(req_block_ids):
         return False

     num_blocks = block_table.num_blocks_per_row[req_index]
-    return (block_table.block_table_np[req_index, :num_blocks] ==
-            req_state.block_ids).all()
+    block_table_values = block_table.block_table_np[req_index, :num_blocks]
+    return (block_table_values == req_block_ids).all()


 def test_update_states_new_request(model_runner):
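The test-side change above tracks the switch from a flat block_ids: list[int] to a per-KV-cache-group block_ids: list[list[int]] (hence [0] becoming [[0]]). Below is a minimal standalone sketch of the normalization that _is_req_state_block_table_match performs; the helper name _first_group_block_ids is hypothetical and not part of vLLM.

from typing import Union


def _first_group_block_ids(
        block_ids: Union[list[int], list[list[int]]]) -> list[int]:
    """Return the block IDs of the first KV cache group."""
    if block_ids and isinstance(block_ids[0], list):
        # New format: one list of block IDs per KV cache group.
        return block_ids[0]
    # Legacy format: a single flat list of block IDs.
    return block_ids


# Legacy single-group request: one block with ID 0.
assert _first_group_block_ids([0]) == [0]
# New layout: the first (and only) KV cache group holds block 0.
assert _first_group_block_ids([[0]]) == [0]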
35 changes: 24 additions & 11 deletions vllm/v1/worker/tpu_model_runner.py
@@ -175,11 +175,21 @@ def __init__(
         self.kv_caches: list[torch.Tensor] = []
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
-        # self.input_batch: InputBatch # Persistent batch.

         # Request states.
         self.requests: dict[str, CachedRequestState] = {}

+        # Initialize input batch early to avoid AttributeError in _update_states
+        self.input_batch = InputBatch(
+            max_num_reqs=self.max_num_reqs,
+            max_model_len=self.max_model_len,
+            max_num_batched_tokens=self.max_num_tokens,
+            device=self.device,
+            pin_memory=self.pin_memory,
+            vocab_size=self.model_config.get_vocab_size(),
+            block_size=self.block_size,
+        )
+
         # Cached torch/numpy tensor
         # The pytorch tensor and numpy array share the same buffer.
         # Sometimes the numpy op is faster so we create both.
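The comment in __init__ above notes that the input batch is now constructed early so that _update_states no longer hits an AttributeError when it runs before initialize_kv_cache. A toy sketch of that ordering hazard and the early-initialization fix (not vLLM code; the class names here are made up):

class LateInitRunner:
    """Input batch is only created in a later setup step (old behavior)."""

    def initialize_kv_cache(self):
        self.input_batch = []  # stand-in for the real InputBatch

    def _update_states(self):
        self.input_batch.clear()  # AttributeError if called before setup


class EarlyInitRunner:
    """Input batch is constructed up front in __init__ (new behavior)."""

    def __init__(self):
        self.input_batch = []  # default batch, usable immediately

    def initialize_kv_cache(self):
        pass  # only rebuilds the batch if the KV cache config requires it

    def _update_states(self):
        self.input_batch.clear()  # always safe


try:
    LateInitRunner()._update_states()
except AttributeError:
    print("old order: _update_states before initialize_kv_cache fails")

EarlyInitRunner()._update_states()  # fine even before initialize_kv_cache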
@@ -1286,16 +1296,19 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                 "Hybrid models with more than one KV cache type are not "
                 "supported yet.")

-        self.input_batch = InputBatch(
-            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.max_model_len,
-            max_num_batched_tokens=self.max_num_tokens,
-            device=self.device,
-            pin_memory=self.pin_memory,
-            vocab_size=self.model_config.get_vocab_size(),
-            block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
-            block_size,
-        )
+        if kv_cache_config.kv_cache_groups[
+                0].kv_cache_spec.block_size != self.block_size:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=self.pin_memory,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
+                block_size,
+            )
+        # Verify dtype compatibility between block_table_cpu and input_batch
         assert self.block_table_cpu.dtype == self.input_batch.block_table[
             0].get_cpu_tensor().dtype

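For context, a sketch of how the updated helpers are typically exercised together. The body of test_update_states_new_request below is an assumption based on the helper signatures visible in the diff, not copied from the repository, and it relies on the test module's model_runner fixture.

def test_update_states_new_request(model_runner):
    req_id = "req_0"
    scheduler_output = _schedule_new_request(req_id)

    # Apply the scheduler output to the runner's persistent state.
    model_runner._update_states(scheduler_output)

    # The request should be tracked, and its block IDs should match the
    # (first group of the) input batch's block table.
    assert _is_req_added(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)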