
Commit 70ad3f9

[Bugfix][TPU] Fix V1 TPU worker for sliding window (#16059)
Signed-off-by: Michael Goin <[email protected]>
Parent: d6fc629

1 file changed (+4 -3 lines)

vllm/v1/worker/tpu_worker.py (+4 -3)
@@ -18,7 +18,7 @@
 from vllm.model_executor import set_random_seed
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import bind_kv_cache
@@ -137,7 +137,7 @@ def determine_available_memory(self) -> int:
         kv_caches: dict[str, torch.Tensor] = {}
         kv_cache_spec = self.model_runner.get_kv_cache_spec()
         for layer_name, layer_spec in kv_cache_spec.items():
-            if isinstance(layer_spec, FullAttentionSpec):
+            if isinstance(layer_spec, AttentionSpec):
                 dtype = layer_spec.dtype

                 # Use an empty tensor instead of `None`` to force Dynamo to pass
@@ -147,7 +147,8 @@ def determine_available_memory(self) -> int:
                                            device=self.device)
                 kv_caches[layer_name] = tpu_kv_cache
             else:
-                raise NotImplementedError
+                raise NotImplementedError(
+                    f"Unsupported KV cache spec '{type(layer_spec)}'")

         runner_kv_caches: list[torch.Tensor] = []
         bind_kv_cache(
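
For context, the functional change is the broadened isinstance check in determine_available_memory: the worker previously allocated profiling KV caches only for layers whose spec was a FullAttentionSpec, so sliding-window layers fell through to a bare NotImplementedError. Checking against the AttentionSpec base type accepts sliding-window layer specs as well, and the error now reports which spec type was unsupported. The snippet below is a minimal, self-contained sketch of that dispatch pattern; the simplified dataclasses and the allocate_profiling_cache helper are illustrative stand-ins, not the real classes from vllm.v1.kv_cache_interface, and they assume the sliding-window spec subclasses the same AttentionSpec base as full attention.

# Sketch only: hypothetical simplified stand-ins for the KV-cache spec classes.
from dataclasses import dataclass


@dataclass
class AttentionSpec:
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: str


@dataclass
class FullAttentionSpec(AttentionSpec):
    pass


@dataclass
class SlidingWindowSpec(AttentionSpec):
    sliding_window: int = 1024


def allocate_profiling_cache(layer_name: str, layer_spec: AttentionSpec) -> str:
    # Pre-fix behavior: isinstance(layer_spec, FullAttentionSpec) rejected
    # sliding-window layers and raised a bare NotImplementedError.
    if isinstance(layer_spec, AttentionSpec):
        return f"{layer_name}: empty profiling cache, dtype={layer_spec.dtype}"
    raise NotImplementedError(
        f"Unsupported KV cache spec '{type(layer_spec)}'")


print(allocate_profiling_cache("layer.0", FullAttentionSpec(16, 8, 128, "bfloat16")))
print(allocate_profiling_cache("layer.1", SlidingWindowSpec(16, 8, 128, "bfloat16")))

Both calls succeed with the base-class check, whereas the pre-fix FullAttentionSpec check would have raised for the sliding-window layer.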
