1 file changed: 8 additions, 3 deletions
@@ -80,15 +80,20 @@ def _allocate_kv_cache(
             )
         except (AttributeError, NotImplementedError):
             kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))
-        kv_cache_shape = tuple(kv_cache_generic_shape[i]
-                               for i in kv_cache_stride_order)
+
+        # The allocation respects the backend-defined stride order to ensure
+        # the semantics remain consistent for each backend. We first obtain
+        # the generic KV cache shape and then permute it according to the
+        # stride order, which could result in a non-contiguous tensor.
+        kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
+                                          for i in kv_cache_stride_order)
 
         for _ in range(self.num_attention_layers):
             # null block in CpuGpuBlockAllocator requires at least that
             # block to be zeroed-out.
             # We zero-out everything for simplicity.
             layer_kv_cache = torch.zeros(
-                kv_cache_shape,
+                kv_cache_allocation_shape,
                 dtype=self.dtype,
                 pin_memory=pin_memory,
                 device=device).permute(*kv_cache_stride_order)
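For context, the new allocation path creates each layer's buffer contiguously in the backend's preferred physical order, then permutes it back so callers still index a generic-shaped (but now non-contiguous) view over backend-ordered memory. Below is a minimal, self-contained sketch of that trick; the concrete shapes, the HND-style stride order, and all names are illustrative assumptions, not vLLM's actual API:

import torch

# Logical ("generic") KV cache shape that attention code indexes with,
# e.g. (2, num_blocks, block_size, num_kv_heads, head_size).
# These sizes are made up for the example.
generic_shape = (2, 4, 16, 8, 64)

# Backend-preferred physical layout, listed outermost-to-innermost.
# Swapping dims 2 and 3 (block_size and num_kv_heads) turns an
# NHD-style layout into an HND-style one.
stride_order = (0, 1, 3, 2, 4)

# Allocate contiguously in the backend's preferred order, then permute
# the view back. Because this stride order swaps exactly two dimensions,
# it is its own inverse, so permuting by the same tuple restores the
# generic dimension order (mirroring what the diff does).
allocation_shape = tuple(generic_shape[i] for i in stride_order)
kv_cache = torch.zeros(allocation_shape).permute(*stride_order)

# The result is logically generic-shaped but physically HND-strided.
assert kv_cache.shape == generic_shape
assert not kv_cache.is_contiguous()

The design point, as the added comment notes, is that only the physical strides change: indexing code throughout the engine keeps seeing the generic shape, while backends that care about memory layout get the ordering they expect from the underlying allocation.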