
Commit ba8952d

Improve comments

Signed-off-by: shuw <[email protected]>
1 parent a267159

1 file changed: +8 −3 lines changed

vllm/worker/cache_engine.py

Lines changed: 8 additions & 3 deletions
@@ -80,15 +80,20 @@ def _allocate_kv_cache(
             )
         except (AttributeError, NotImplementedError):
             kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))
-        kv_cache_shape = tuple(kv_cache_generic_shape[i]
-                               for i in kv_cache_stride_order)
+
+        # The allocation respects the backend-defined stride order to ensure
+        # the semantic remains consistent for each backend. We first obtain the
+        # generic kv cache shape and then permute it according to the stride
+        # order which could result in a non-contiguous tensor.
+        kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
+                                          for i in kv_cache_stride_order)
 
         for _ in range(self.num_attention_layers):
             # null block in CpuGpuBlockAllocator requires at least that
             # block to be zeroed-out.
             # We zero-out everything for simplicity.
             layer_kv_cache = torch.zeros(
-                kv_cache_shape,
+                kv_cache_allocation_shape,
                 dtype=self.dtype,
                 pin_memory=pin_memory,
                 device=device).permute(*kv_cache_stride_order)
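For illustration, here is a minimal standalone sketch of the pattern the new comment describes: allocate with the dimensions in the backend's preferred physical order, then permute back so the tensor is indexed by the generic shape. The shape and stride order below are made-up example values, not what any actual vllm backend reports.

    import torch

    # Made-up example values (not from any real attention backend):
    # generic shape (2, num_blocks, block_size, num_heads, head_size).
    kv_cache_generic_shape = (2, 16, 8, 4, 32)
    # Hypothetical backend preference: num_blocks outermost in memory.
    # This order swaps the first two dims, so applying the same order
    # again in permute() round-trips back to the generic shape.
    kv_cache_stride_order = (1, 0, 2, 3, 4)

    # Allocate in the backend's physical order, then permute so callers
    # index the tensor with the generic (semantic) shape.
    kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
                                      for i in kv_cache_stride_order)
    layer_kv_cache = torch.zeros(kv_cache_allocation_shape).permute(
        *kv_cache_stride_order)

    print(layer_kv_cache.shape)     # torch.Size([2, 16, 8, 4, 32])
    print(layer_kv_cache.stride())  # (1024, 2048, 128, 32, 1)
    print(layer_kv_cache.is_contiguous())  # False

The permute is a zero-copy view, so the cost is paid once at allocation time: downstream code sees the generic shape, while the strides encode the backend's layout, which is why the resulting tensor can be non-contiguous.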
