Skip to content

Commit db70231

Browse files
committed
Update input batch.
Signed-off-by: Carol Zheng <[email protected]>
1 parent d29124c commit db70231

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

vllm/v1/worker/tpu_model_runner.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,18 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
12961296
"Hybrid models with more than one KV cache type are not "
12971297
"supported yet.")
12981298

1299+
if kv_cache_config.kv_cache_groups[
1300+
0].kv_cache_spec.block_size != self.block_size:
1301+
self.input_batch = InputBatch(
1302+
max_num_reqs=self.max_num_reqs,
1303+
max_model_len=self.max_model_len,
1304+
max_num_batched_tokens=self.max_num_tokens,
1305+
device=self.device,
1306+
pin_memory=self.pin_memory,
1307+
vocab_size=self.model_config.get_vocab_size(),
1308+
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
1309+
block_size,
1310+
)
12991311
# Verify dtype compatibility between block_table_cpu and input_batch
13001312
assert self.block_table_cpu.dtype == self.input_batch.block_table[
13011313
0].get_cpu_tensor().dtype

0 commit comments

Comments
 (0)