Allocate kv_cache with stride order #16605

Merged 2 commits on Apr 26, 2025
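
For context, FlashInfer's paged KV cache has the logical shape (num_blocks, 2, block_size, num_kv_heads, head_size); the "NHD" layout stores it contiguously in that order, while "HND" keeps the same logical shape but puts the head dimension outermost within each block in memory. A minimal PyTorch sketch of the allocate-then-permute trick this PR builds on (dimension sizes are illustrative, not taken from the PR):

```python
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
logical_shape = (num_blocks, 2, block_size, num_kv_heads, head_size)

# The identity order (0, 1, 2, 3, 4) would give NHD; swapping block_size and
# num_kv_heads gives HND.
stride_order = (0, 1, 3, 2, 4)

# Allocate in the permuted (physical) shape, then permute the view back to the
# logical dimension order. (0, 1, 3, 2, 4) happens to be its own inverse, so
# permuting by stride_order again recovers the logical order.
allocation_shape = tuple(logical_shape[i] for i in stride_order)
kv_cache = torch.zeros(allocation_shape).permute(*stride_order)

print(kv_cache.shape)            # torch.Size([8, 2, 16, 4, 64]) -- logical order
print(kv_cache.stride())         # (8192, 4096, 64, 1024, 1) -- heads outermost per block
print(kv_cache.is_contiguous())  # False
```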
39 changes: 21 additions & 18 deletions csrc/cache_kernels.cu
@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
// head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size,
const float* k_scale, const float* v_scale) {
const int64_t block_stride, const int64_t page_stride,
const int64_t head_stride, const int64_t key_stride,
const int64_t value_stride, const int num_heads, const int head_size,
const int block_size, const float* k_scale, const float* v_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_key_value_idx = block_idx * block_stride +
block_offset * num_heads * head_size +
head_idx * head_size + head_offset;
block_offset * page_stride +
head_idx * head_stride + head_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
@@ -396,16 +397,16 @@ void reshape_and_cache(
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
head_stride, key_stride, value_stride, num_heads, head_size, \
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));

void reshape_and_cache_flash(
@@ -432,9 +433,11 @@ void reshape_and_cache_flash(
int head_size = key.size(2);
int block_size = key_cache.size(1);

int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = key_cache.stride(0);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int64_t block_stride = key_cache.stride(0);
int64_t page_stride = key_cache.stride(1);
int64_t head_stride = key_cache.stride(2);
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));

dim3 grid(num_tokens);
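
To sanity-check the index arithmetic above: a small Python reference (purely illustrative, not part of the diff) showing that once block_stride, page_stride and head_stride are read off the cache tensor, the same flat-index formula addresses the underlying buffer correctly for both NHD and HND layouts:

```python
import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64


def flat_index(cache, block_idx, block_offset, head_idx, head_offset):
    # Mirrors the kernel: all strides come straight from the cache tensor,
    # so the physical layout is irrelevant here.
    return (block_idx * cache.stride(0) + block_offset * cache.stride(1) +
            head_idx * cache.stride(2) + head_offset * cache.stride(3))


buf = torch.arange(num_blocks * block_size * num_heads * head_size,
                   dtype=torch.float32)

# Two views of the same buffer with the logical shape
# [num_blocks, block_size, num_heads, head_size] but different physical layouts.
nhd = buf.as_strided((num_blocks, block_size, num_heads, head_size),
                     (block_size * num_heads * head_size,
                      num_heads * head_size, head_size, 1))
hnd = buf.as_strided((num_blocks, block_size, num_heads, head_size),
                     (block_size * num_heads * head_size,
                      head_size, block_size * head_size, 1))

for cache in (nhd, hnd):
    idx = flat_index(cache, block_idx=2, block_offset=5, head_idx=3,
                     head_offset=7)
    assert buf[idx] == cache[2, 5, 3, 7]
```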
60 changes: 41 additions & 19 deletions tests/kernels/attention/test_cache.py
@@ -16,6 +16,7 @@
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]
CACHE_LAYOUTS = ["NHD", "HND"]

# Parameters for MLA tests.
KV_LORA_RANKS = [512]
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
kv_cache_factory_flashinfer,
@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash(
seed: int,
device: str,
kv_cache_dtype: str,
kv_cache_layout: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)

# fp8 conversion requires contiguous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
num_tokens = num_tokens // 2
num_blocks = num_blocks // 2
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)

qkv = torch.randn(num_tokens,
3,
num_heads,
@@ -261,44 +267,56 @@
kv_cache_dtype,
dtype,
device=device,
cache_layout=kv_cache_layout,
)
key_cache, value_cache = key_caches[0].contiguous(
), value_caches[0].contiguous()
key_cache, value_cache = key_caches[0], value_caches[0]
del key_caches
del value_caches

k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)

def permute_and_compact(x):
y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
return y.contiguous()

key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)

# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
cloned_key_cache = torch.empty_like(key_cache_compact,
dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache_compact,
dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache_compact,
v_scale.item(), kv_cache_dtype)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()

cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale)
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)

if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
result_key_cache = torch.empty_like(key_cache_compact,
dtype=torch.float16)
ops.convert_fp8(result_key_cache,
key_cache,
key_cache_compact,
k_scale.item(),
kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
result_value_cache = torch.empty_like(value_cache_compact,
dtype=torch.float16)
ops.convert_fp8(result_value_cache,
value_cache,
value_cache_compact,
v_scale.item(),
kv_dtype=kv_cache_dtype)

@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash(
for i in range(num_tokens):
block_idx = block_indicies_lst[i]
block_offset = block_offsets_lst[i]
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
if kv_cache_layout == "NHD":
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
else:
cloned_key_cache[block_idx, :, block_offset, :] = key[i]
cloned_value_cache[block_idx, :, block_offset, :] = value[i]

if kv_cache_dtype == "fp8":
torch.testing.assert_close(result_key_cache,
@@ -323,8 +345,8 @@
atol=0.001,
rtol=0.1)
else:
torch.testing.assert_close(key_cache, cloned_key_cache)
torch.testing.assert_close(value_cache, cloned_value_cache)
torch.testing.assert_close(key_cache_compact, cloned_key_cache)
torch.testing.assert_close(value_cache_compact, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
4 changes: 4 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -77,6 +77,10 @@ def get_kv_cache_shape(
) -> Tuple[int, ...]:
raise NotImplementedError

@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
raise NotImplementedError

@staticmethod
@abstractmethod
def swap_blocks(
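
The new hook is opt-in: a backend that wants a non-default physical layout overrides it, while backends that keep the base implementation fall back to the identity order because CacheEngine catches the NotImplementedError (see the vllm/worker/cache_engine.py hunk below). A hypothetical override might look like the following sketch (illustrative only, not a backend that exists in the tree):

```python
from typing import Tuple

from vllm.attention.backends.abstract import AttentionBackend


class HNDBackend(AttentionBackend):
    """Hypothetical backend, shown only to illustrate the new hook."""

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        # Keep the logical (num_blocks, 2, block_size, num_kv_heads, head_size)
        # shape, but store heads outermost within each block (HND).
        return (0, 1, 3, 2, 4)
```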
30 changes: 25 additions & 5 deletions vllm/attention/backends/flashinfer.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import dataclasses
import os
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
@@ -48,6 +49,9 @@
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)

FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
"NHD").upper()


class FlashInferBackend(AttentionBackend):

@@ -80,6 +84,14 @@ def get_kv_cache_shape(
) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size)

@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
cache_layout = FLASHINFER_KV_CACHE_LAYOUT
assert (cache_layout in ("NHD", "HND"))
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3,
2, 4)
return stride_order

@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
@@ -188,6 +200,7 @@ def __init__(self, runner):
self.global_hyperparameters: Optional[PerLayerParameters] = None

self.vllm_config = self.runner.vllm_config
self._kv_cache_layout = None

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
@@ -197,10 +210,15 @@ def _get_workspace_buffer(self):
device=self.runner.device)
return self._workspace_buffer

def get_kv_cache_layout(self):
if self._kv_cache_layout is None:
self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT
return self._kv_cache_layout

def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD")
self._get_workspace_buffer(), self.get_kv_cache_layout())
return self._prefill_wrapper

def _get_decode_wrapper(self):
@@ -213,7 +231,7 @@ def _get_decode_wrapper(self):
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
self.get_kv_cache_layout(),
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper

@@ -274,7 +292,8 @@ def graph_capture_get_metadata_for_batch(
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
self._graph_indices_buffer, _last_page_len_buffer,
self.get_kv_cache_layout(),
use_tensor_cores)
if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -1005,6 +1024,7 @@ def forward(

prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
stride_order = FlashInferBackend.get_kv_cache_stride_order()
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
@@ -1036,7 +1056,7 @@

prefill_output = prefill_meta.prefill_wrapper.run(
query,
kv_cache,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
@@ -1051,7 +1071,7 @@

decode_output = decode_meta.decode_wrapper.run(
decode_query,
kv_cache,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
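
On the consumer side, the permute before handing the cache to the FlashInfer wrappers is only a change of view, not a copy: for HND it re-exposes the logically shaped cache as (num_blocks, 2, num_kv_heads, block_size, head_size), which is then contiguous. A rough sketch, assuming the cache was allocated with the backend stride order as in CacheEngine (sizes are illustrative):

```python
import torch

stride_order = (0, 1, 3, 2, 4)  # HND, as returned by get_kv_cache_stride_order()

# (num_blocks, 2, block_size, num_kv_heads, head_size)
logical_shape = (8, 2, 16, 4, 64)
allocation_shape = tuple(logical_shape[i] for i in stride_order)

# Cache as CacheEngine hands it to the attention layer: logical shape, HND layout.
kv_cache = torch.zeros(allocation_shape).permute(*stride_order)

# What forward() passes to prefill_wrapper.run() / decode_wrapper.run():
flashinfer_view = kv_cache.permute(*stride_order)
print(flashinfer_view.shape)            # (8, 2, 4, 16, 64): heads before tokens
print(flashinfer_view.is_contiguous())  # True, and no data was moved
```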
13 changes: 10 additions & 3 deletions vllm/utils.py
@@ -762,21 +762,28 @@ def create_kv_caches_with_random_flash(
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = None,
device: Optional[str] = "cuda",
cache_layout: Optional[str] = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)

torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
assert cache_layout in ("NHD", "HND")
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2,
4)

kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i]
for i in stride_order)
scale = head_size**-0.5

key_caches: list[torch.Tensor] = []
value_caches: list[torch.Tensor] = []

for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape,
key_value_cache = torch.empty(size=kv_cache_allocation_shape,
dtype=torch_dtype,
device=device)
device=device).permute(*stride_order)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8':
23 changes: 18 additions & 5 deletions vllm/worker/cache_engine.py
@@ -71,19 +71,32 @@ def _allocate_kv_cache(
device: str,
) -> List[torch.Tensor]:
"""Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))

# The allocation respects the backend-defined stride order to ensure
# the semantics remain consistent for each backend. We first obtain the
# generic kv cache shape and then permute it according to the stride
# order, which could result in a non-contiguous tensor.
kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
for i in kv_cache_stride_order)

for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
layer_kv_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device)
layer_kv_cache = torch.zeros(
kv_cache_allocation_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device).permute(*kv_cache_stride_order)

# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D