
[v1] Refactor KVCacheConfig #14079


Merged 12 commits on Mar 21, 2025
110 changes: 109 additions & 1 deletion tests/v1/core/test_kv_cache_utils.py
@@ -1,14 +1,18 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, PrefixCachingMetrics,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
hash_request_tokens,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request

@@ -314,3 +318,107 @@ def stats(requests, queries, hits):
assert metrics.aggregated_query_total == 0
assert metrics.aggregated_query_hit == 0
assert not metrics.query_queue


def test_unify_kv_cache_configs():

def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla)

same_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
]
unify_kv_cache_configs(same_kv_cache_config)
assert same_kv_cache_config[0].num_blocks == 10
assert same_kv_cache_config[1].num_blocks == 10

need_sort_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
],
),
]

unify_kv_cache_configs(need_sort_kv_cache_config)
assert need_sort_kv_cache_config[0].num_blocks == 10
assert need_sort_kv_cache_config[1].num_blocks == 10

diff_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=8)),
],
),
]
with pytest.raises(AssertionError):
unify_kv_cache_configs(diff_kv_cache_config)
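
(Not part of the diff.) The contract exercised by test_unify_kv_cache_configs above can be reduced to a minimal standalone sketch. It uses only the classes imported in this test file; the single-layer configs are an illustrative simplification of the two-layer ones the test builds.

import torch

from vllm.v1.core.kv_cache_utils import unify_kv_cache_configs
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheTensor)

# Two workers report the same group layout but different block counts.
spec = FullAttentionSpec(block_size=16, num_kv_heads=2, head_size=64,
                         dtype=torch.float32, use_mla=False)
configs = [
    KVCacheConfig(num_blocks=10,
                  tensors={"layer1": KVCacheTensor(100)},
                  kv_cache_groups=[KVCacheGroupSpec(["layer1"], spec)]),
    KVCacheConfig(num_blocks=20,
                  tensors={"layer1": KVCacheTensor(100)},
                  kv_cache_groups=[KVCacheGroupSpec(["layer1"], spec)]),
]
unify_kv_cache_configs(configs)
# Both configs are clamped to the smallest num_blocks across workers.
assert all(cfg.num_blocks == 10 for cfg in configs)
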
130 changes: 95 additions & 35 deletions vllm/v1/core/kv_cache_utils.py
@@ -7,8 +7,8 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
KVCacheTensor)
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, KVCacheTensor)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request

@@ -449,15 +449,15 @@ def hash_request_tokens(block_size: int,


def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int):
"""
Checks whether `available_memory` is enough for the KV cache to hold at
least one request with the model's max_model_len.

Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.

Raises:
@@ -484,12 +484,43 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
f"`max_model_len` when initializing the engine.")


def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
def create_kv_cache_group_specs(
kv_cache_spec: dict[str, KVCacheSpec],
grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
"""
Create a KVCacheGroupSpec object for each kv cache group.
The layers in the same group should share the same
KVCacheSpec.

Args:
kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names:
A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same
KVCacheSpec.
Returns:
A list of KVCacheGroupSpec objects, one for each group.
"""
kv_cache_groups = []
for layer_names_one_group in grouped_layer_names:
layer_spec = kv_cache_spec[layer_names_one_group[0]]
assert all(
kv_cache_spec[layer_name] == layer_spec
for layer_name in layer_names_one_group[1:]), (
"All layers in the same KV cache group must share the same "
"KVCacheSpec.")
kv_cache_groups.append(
KVCacheGroupSpec(layer_names_one_group, layer_spec))
return kv_cache_groups


def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
"""
Whether all layers in the given KVCacheSpec have the same type of KV cache.

Args:
kv_cache_spec: The KVCacheSpec of the model
kv_cache_spec: The kv cache spec of each attention layer in the model

Returns:
True if all layers have the same type, False otherwise.
@@ -500,18 +531,16 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:


def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
available_memory: int,
num_layers: int) -> KVCacheConfig:
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.

Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
num_layers: The number of layers in the model.

Returns:
The generated KVCacheConfig
@@ -521,7 +550,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
assert len(page_sizes) == 1
page_size = page_sizes.pop()

num_blocks = int(available_memory // page_size // num_layers)
num_blocks = int(available_memory // page_size // len(kv_cache_spec))
num_blocks = max(num_blocks, 0)

if vllm_config.cache_config.num_gpu_blocks_override is not None:
@@ -541,48 +570,79 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
max_model_len_str, max_concurrency)

per_layer_size = page_size * num_blocks
# All layers have the same KV cache spec, so we create one kv cache group
# for all layers.
grouped_layer_names = [list(kv_cache_spec.keys())]

kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
tensors={
layer_name: KVCacheTensor(size=per_layer_size)
for layer_name in kv_cache_spec
},
groups=[[layer_name for layer_name in kv_cache_spec]],
kv_cache_spec=kv_cache_spec)
kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
grouped_layer_names),
)
return kv_cache_config


def get_kv_cache_configs(vllm_config: VllmConfig,
kv_cache_specs: list[KVCacheSpec],
available_memory: int) -> list[KVCacheConfig]:
def get_kv_cache_config(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.

Args:
vllm_config: The global VllmConfig
kv_cache_specs: The kv cache specs of the model
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.

Returns:
The generated KVCacheConfigs
"""
# Use the max number of layers to conservatively determine
# the number of blocks.
num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
kv_cache_configs = []
for kv_cache_spec in kv_cache_specs:
check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
kv_cache_configs.append(
_get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory,
num_layers))
else:
raise NotImplementedError
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)

raise NotImplementedError


def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
"""
Make the KV cache configurations for each worker consistent, so that all
workers can be controlled by the same KVCacheManager.
This function verifies that the layer groups of each worker are the same,
and changes the num_blocks of each worker to the smallest among all workers.

Args:
kv_cache_configs: The KV cache configurations for each worker. Will be
in-place modified to make them consistent.
"""

# Sort the kv cache groups by the type_id of their KV cache spec.
# This can avoid the inconsistency caused by the order of groups.
for kv_cache_config in kv_cache_configs:
kv_cache_config.kv_cache_groups.sort(
key=lambda x: x.kv_cache_spec.type_id)

# Verify that the groups of each rank are the same.
for kv_cache_config in kv_cache_configs[1:]:
for group_rank_0, group_rank_i in zip(
kv_cache_configs[0].kv_cache_groups,
kv_cache_config.kv_cache_groups):
assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec

# Change the num_blocks of each rank to the smallest among all ranks. We
# do not need to shrink the tensor size because it is valid to only use the
# first `num_blocks` blocks of the tensor.
min_num_blocks = min(kv_cache_config.num_blocks
for kv_cache_config in kv_cache_configs)
for kv_cache_config in kv_cache_configs:
kv_cache_config.num_blocks = min_num_blocks

return kv_cache_configs
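
(Not part of the diff.) Taken together, the refactor splits the old get_kv_cache_configs into a per-worker get_kv_cache_config plus a cross-worker unify_kv_cache_configs. Below is a rough sketch of how the two helpers are meant to compose, mirroring the engine-core change shown next; the wrapper name build_unified_configs is hypothetical and used only for illustration.

from vllm.config import VllmConfig
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec


def build_unified_configs(
        vllm_config: VllmConfig,
        kv_cache_specs: list[dict[str, KVCacheSpec]],
        available_gpu_memory: list[int]) -> list[KVCacheConfig]:
    # One config per worker, computed from that worker's own spec and
    # free GPU memory.
    kv_cache_configs = [
        get_kv_cache_config(vllm_config, spec, mem)
        for spec, mem in zip(kv_cache_specs, available_gpu_memory)
    ]
    # Clamp every config to the smallest num_blocks so the shared
    # scheduler/KVCacheManager sees one consistent value on all workers.
    unify_kv_cache_configs(kv_cache_configs)
    return kv_cache_configs
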
31 changes: 22 additions & 9 deletions vllm/v1/engine/core.py
@@ -21,7 +21,8 @@
maybe_register_config_serialize_by_value)
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
@@ -120,15 +121,27 @@ def _initialize_kv_caches(self,
# memory can be allocated for kv cache.
available_gpu_memory = self.model_executor.determine_available_memory()

assert len(kv_cache_specs) == len(available_gpu_memory)
# Get the kv cache tensor size
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
available_gpu_memory)
num_gpu_blocks_set = set(config.num_blocks
for config in kv_cache_configs)
assert len(num_gpu_blocks_set) == 1, (
f"num_gpu_blocks need to be the same across workers, "
f"but they are different: {num_gpu_blocks_set}")
num_gpu_blocks = num_gpu_blocks_set.pop()
kv_cache_configs = [
get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
available_gpu_memory_one_worker)
for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
zip(kv_cache_specs, available_gpu_memory)
]

# Since we use a shared centralized controller, we need the
# `kv_cache_config` to be consistent across all workers to make sure
# all the memory operators can be applied to all workers.
unify_kv_cache_configs(kv_cache_configs)

# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to get the number of blocks.
assert all([
cfg.num_blocks == kv_cache_configs[0].num_blocks
for cfg in kv_cache_configs
])
num_gpu_blocks = kv_cache_configs[0].num_blocks
num_cpu_blocks = 0

# Initialize kv cache and warmup the execution