
Commit c845f81

Reshape_and_cache_flash kernel to be kv-cache layout aware.

Signed-off-by: shuw <[email protected]>

1 parent bb103b2 commit c845f81

File tree (5 files changed, +39 -20 lines)

  csrc/cache.h
  csrc/cache_kernels.cu
  csrc/torch_bindings.cpp
  tests/kernels/test_cache.py
  vllm/_custom_ops.py

csrc/cache.h

Lines changed: 3 additions & 2 deletions
@@ -29,7 +29,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& value_cache,
                              torch::Tensor& slot_mapping,
                              const std::string& kv_cache_dtype,
-                             torch::Tensor& k_scale, torch::Tensor& v_scale);
+                             torch::Tensor& k_scale, torch::Tensor& v_scale,
+                             const bool is_NHD = true);
 
 void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                           torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
@@ -45,4 +46,4 @@ void gather_cache(
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
+    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
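
The new is_NHD flag defaults to true, so existing callers keep the current layout; passing false selects the alternate layout documented in the cache_kernels.cu comments below. As an illustrative PyTorch sketch (not part of this commit, arbitrary shapes):

  import torch

  num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64

  # NHD (is_NHD=True, the default): per-block token slots come before heads.
  key_cache_nhd = torch.empty(num_blocks, block_size, num_heads, head_size)

  # HND (is_NHD=False): heads come before the per-block token slots.
  key_cache_hnd = torch.empty(num_blocks, num_heads, block_size, head_size)

  # The two layouts are permutations of each other along dims 1 and 2.
  assert key_cache_nhd.permute(0, 2, 1, 3).shape == key_cache_hnd.shape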

csrc/cache_kernels.cu

Lines changed: 18 additions & 11 deletions
@@ -265,14 +265,14 @@ template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    cache_t* __restrict__ key_cache,    // [num_blocks, block_size, num_heads,
-                                        //  head_size]
-    cache_t* __restrict__ value_cache,  // [num_blocks, block_size, num_heads,
-                                        //  head_size]
+    cache_t* __restrict__ key_cache, cache_t* __restrict__ value_cache,
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride, const int key_stride, const int value_stride,
     const int num_heads, const int head_size, const int block_size,
-    const float* k_scale, const float* v_scale) {
+    const float* k_scale, const float* v_scale, const bool is_NHD) {
+  // For key/value_cache layout:
+  // - NHD: [num_blocks, block_size, num_heads, head_size]
+  // - HND: [num_blocks, num_heads, block_size, head_size]
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -287,9 +287,12 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_key_value_idx = block_idx * block_stride +
-                                      block_offset * num_heads * head_size +
-                                      head_idx * head_size + head_offset;
+    const int64_t tgt_key_value_idx =
+        block_idx * block_stride +
+        (is_NHD ? block_offset * num_heads * head_size + head_idx * head_size +
+                      head_offset
+                : head_idx * block_size * head_size + block_offset * head_size +
+                      head_offset);
     scalar_t tgt_key = key[src_key_idx];
     scalar_t tgt_value = value[src_value_idx];
     if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
@@ -416,7 +419,7 @@ void reshape_and_cache_flash(
         value_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale, const bool is_NHD) {
   // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
   // slot_mapping.size(0) because of padding for CUDA graphs.
   // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
@@ -427,10 +430,14 @@ void reshape_and_cache_flash(
   // before padding.
   // For compatibility with both cases, we use slot_mapping.size(0) as the
   // number of tokens.
-  int num_tokens = slot_mapping.size(0);
+  // For key/value_cache layout:
+  // - NHD: [num_blocks, block_size, num_heads, head_size]
+  // - HND: [num_blocks, num_heads, block_size, head_size]
+
   int num_heads = key.size(1);
+  int num_tokens = slot_mapping.size(0);
   int head_size = key.size(2);
-  int block_size = key_cache.size(1);
+  int block_size = is_NHD ? key_cache.size(1) : key_cache.size(2);
 
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
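
As a sanity check on the target-index arithmetic above, here is a minimal PyTorch reference (illustrative only, not part of this commit; reference_store and the shapes are placeholders) that mirrors the kernel's tgt_key_value_idx computation for both layouts:

  import torch

  def reference_store(key_cache: torch.Tensor, key: torch.Tensor,
                      block_idx: int, block_offset: int, is_NHD: bool) -> None:
      # Mirror the kernel's flat-index computation on a contiguous cache.
      if is_NHD:
          _, block_size, num_heads, head_size = key_cache.shape
      else:
          _, num_heads, block_size, head_size = key_cache.shape
      block_stride = key_cache.stride(0)
      flat = key_cache.view(-1)
      for head_idx in range(num_heads):
          for head_offset in range(head_size):
              if is_NHD:
                  tgt = (block_idx * block_stride +
                         block_offset * num_heads * head_size +
                         head_idx * head_size + head_offset)
              else:
                  tgt = (block_idx * block_stride +
                         head_idx * block_size * head_size +
                         block_offset * head_size + head_offset)
              flat[tgt] = key[head_idx, head_offset]

  # Writing one token through the flat index lands in the same slot as
  # plain advanced indexing for either layout.
  key = torch.randn(8, 64)
  nhd = torch.zeros(4, 16, 8, 64)
  hnd = torch.zeros(4, 8, 16, 64)
  reference_store(nhd, key, block_idx=2, block_offset=5, is_NHD=True)
  reference_store(hnd, key, block_idx=2, block_offset=5, is_NHD=False)
  assert torch.equal(nhd[2, 5], key)
  assert torch.equal(hnd[2, :, 5], key)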

csrc/torch_bindings.cpp

Lines changed: 2 additions & 1 deletion
@@ -570,7 +570,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       " Tensor! value_cache,"
       " Tensor slot_mapping,"
       " str kv_cache_dtype,"
-      " Tensor k_scale, Tensor v_scale) -> ()");
+      " Tensor k_scale, Tensor v_scale,"
+      " bool is_NHD) -> ()");
   cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                  &reshape_and_cache_flash);
 

tests/kernels/test_cache.py

Lines changed: 14 additions & 5 deletions
@@ -16,6 +16,7 @@
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 120, 256]
 BLOCK_SIZES = [8, 16, 32]
+CACHE_LAYOUTS = ["NHD", "HND"]
 
 # Parameters for MLA tests.
 KV_LORA_RANKS = [512]
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("kv_layout", CACHE_LAYOUTS)
 @torch.inference_mode()
 def test_reshape_and_cache_flash(
     kv_cache_factory_flashinfer,
@@ -232,6 +234,7 @@ def test_reshape_and_cache_flash(
     seed: int,
     device: str,
     kv_cache_dtype: str,
+    kv_layout: str,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -242,7 +245,7 @@ def test_reshape_and_cache_flash(
     slot_mapping = torch.tensor(slot_mapping_lst,
                                 dtype=torch.long,
                                 device=device)
-
+    is_NHD = kv_layout == "NHD"
     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
@@ -261,6 +264,7 @@ def test_reshape_and_cache_flash(
         kv_cache_dtype,
         dtype,
         device=device,
+        is_NHD=is_NHD,
     )
     key_cache, value_cache = key_caches[0].contiguous(
     ), value_caches[0].contiguous()
@@ -285,10 +289,11 @@ def test_reshape_and_cache_flash(
     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
             (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
-             k_scale, v_scale),
+             k_scale, v_scale, is_NHD),
             cond=(head_size == HEAD_SIZES[0]))
     ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype, k_scale, v_scale)
+                                slot_mapping, kv_cache_dtype, k_scale, v_scale,
+                                is_NHD)
 
     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
@@ -310,8 +315,12 @@ def test_reshape_and_cache_flash(
     for i in range(num_tokens):
         block_idx = block_indicies_lst[i]
         block_offset = block_offsets_lst[i]
-        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
-        cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+        if is_NHD:
+            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
+            cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+        else:
+            cloned_key_cache[block_idx, :, block_offset, :] = key[i]
+            cloned_value_cache[block_idx, :, block_offset, :] = value[i]
 
     if kv_cache_dtype == "fp8":
         torch.testing.assert_close(result_key_cache,

vllm/_custom_ops.py

Lines changed: 2 additions & 1 deletion
@@ -1272,11 +1272,12 @@ def reshape_and_cache_flash(
     kv_cache_dtype: str,
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
+    is_NHD: bool = True,
 ) -> None:
     torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                    value_cache, slot_mapping,
                                                    kv_cache_dtype, k_scale,
-                                                   v_scale)
+                                                   v_scale, is_NHD)
 
 
 def concat_and_cache_mla(
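
For completeness, a hedged usage sketch of the updated wrapper with an HND-layout cache; this assumes a CUDA device and a vLLM build that includes this commit, and all shapes and slot values are arbitrary:

  import torch
  from vllm import _custom_ops as ops

  num_tokens, num_heads, head_size = 3, 8, 64
  num_blocks, block_size = 4, 16

  key = torch.randn(num_tokens, num_heads, head_size,
                    device="cuda", dtype=torch.half)
  value = torch.randn_like(key)
  # HND layout: [num_blocks, num_heads, block_size, head_size]
  key_cache = torch.zeros(num_blocks, num_heads, block_size, head_size,
                          device="cuda", dtype=torch.half)
  value_cache = torch.zeros_like(key_cache)
  slot_mapping = torch.tensor([0, 1, 17], device="cuda", dtype=torch.long)
  k_scale = torch.ones((), device="cuda", dtype=torch.float32)
  v_scale = torch.ones((), device="cuda", dtype=torch.float32)

  ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping,
                              "auto", k_scale, v_scale, is_NHD=False)
  # Slot 17 maps to block 1, offset 1 of the HND-layout cache.
  assert torch.equal(key_cache[1, :, 1], key[2])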
