From c209a59fb67f4a745d8d6501c40814fa6d1273fa Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 1 Mar 2025 04:54:19 -0800 Subject: [PATCH 1/8] kv cache config refactor Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 102 ++++++++++++++++++++- vllm/v1/core/kv_cache_utils.py | 127 +++++++++++++++++++-------- vllm/v1/engine/core.py | 25 ++++-- vllm/v1/executor/abstract.py | 15 ++-- vllm/v1/kv_cache_interface.py | 48 ++++++---- vllm/v1/worker/gpu_model_runner.py | 37 ++++---- vllm/v1/worker/gpu_worker.py | 4 +- vllm/v1/worker/tpu_model_runner.py | 44 +++++----- vllm/v1/worker/tpu_worker.py | 2 +- vllm/v1/worker/worker_base.py | 4 +- 10 files changed, 293 insertions(+), 115 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index ba08b83ec54..ab7486bf01e 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest +import torch from vllm.multimodal.inputs import MultiModalKwargs from vllm.sampling_params import SamplingParams @@ -8,7 +9,10 @@ KVCacheBlock, PrefixCachingMetrics, generate_block_hash_extra_keys, hash_block_tokens, - hash_request_tokens) + hash_request_tokens, + make_kv_cache_configs_consistent) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheTensor, VirtualLayer) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -314,3 +318,99 @@ def stats(requests, queries, hits): assert metrics.aggregated_query_total == 0 assert metrics.aggregated_query_hit == 0 assert not metrics.query_queue + + +def test_make_kv_cache_configs_consistent(): + + def new_kv_cache_spec(block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float32): + return FullAttentionSpec(block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype) + + same_kv_cache_config = [ + KVCacheConfig( + num_blocks=10, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + virtual_layers=[ + VirtualLayer(["layer1"], new_kv_cache_spec()), + VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + ], + ), + KVCacheConfig( + num_blocks=20, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + virtual_layers=[ + VirtualLayer(["layer1"], new_kv_cache_spec()), + VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + ], + ), + ] + make_kv_cache_configs_consistent(same_kv_cache_config) + assert same_kv_cache_config[0].num_blocks == 10 + assert same_kv_cache_config[1].num_blocks == 10 + + need_sort_kv_cache_config = [ + KVCacheConfig( + num_blocks=10, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + virtual_layers=[ + VirtualLayer(["layer1"], new_kv_cache_spec()), + VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + ], + ), + KVCacheConfig( + num_blocks=20, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + virtual_layers=[ + VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + VirtualLayer(["layer1"], new_kv_cache_spec()), + ], + ), + ] + + make_kv_cache_configs_consistent(need_sort_kv_cache_config) + assert need_sort_kv_cache_config[0].num_blocks == 10 + assert need_sort_kv_cache_config[1].num_blocks == 10 + + diff_kv_cache_config = [ + KVCacheConfig( + num_blocks=10, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + virtual_layers=[ + VirtualLayer(["layer1"], 
new_kv_cache_spec()), + VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + ], + ), + KVCacheConfig( + num_blocks=20, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + virtual_layers=[ + VirtualLayer(["layer1"], new_kv_cache_spec()), + VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=8)), + ], + ), + ] + with pytest.raises(AssertionError): + make_kv_cache_configs_consistent(diff_kv_cache_config) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index e3eb6b24c19..62a9040c502 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -3,12 +3,12 @@ from collections import deque from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, List, NamedTuple, Optional, Tuple from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, - KVCacheTensor) + KVCacheTensor, VirtualLayer) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -436,7 +436,7 @@ def hash_request_tokens(block_size: int, def check_enough_kv_cache_memory(vllm_config: VllmConfig, - kv_cache_spec: KVCacheSpec, + kv_cache_spec: Dict[str, KVCacheSpec], available_memory: int): """ Checks whether `available_memory` is enough for the KV cache to hold at @@ -444,7 +444,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_spec: The kv cache spec of each attention layer in the model available_memory: Memory available for KV cache in bytes. Raises: @@ -471,12 +471,42 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, f"`max_model_len` when initializing the engine.") -def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: +def create_virtual_layer( + kv_cache_spec: Dict[str, KVCacheSpec], + virtual_layer_map: List[List[str]]) -> List[VirtualLayer]: + """ + Create VirtualLayer object for each virtual layer. + The layers represented by the same virtual layer should share the same + KVCacheSpec. + + Args: + kv_cache_spec: + A mapping from each layer name to its corresponding KVCacheSpec. + virtual_layer_map: + A list of virtual layers, where each element is a list of layer + names that represented by the same virtual layer and should share + the same KVCacheSpec. + Returns: + A list of VirtualLayer objects, one for each virtual layer. + """ + virtual_layers = [] + for layer_names in virtual_layer_map: + layer_spec = kv_cache_spec[layer_names[0]] + assert all( + kv_cache_spec[layer_name] == layer_spec + for layer_name in layer_names[1:] + ), ("All layers represented by one virtual layer must share the same " + "KVCacheSpec.") + virtual_layers.append(VirtualLayer(layer_names, layer_spec)) + return virtual_layers + + +def is_kv_cache_type_uniform(kv_cache_spec: Dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same type of KV cache. Args: - kv_cache_spec: The KVCacheSpec of the model + kv_cache_spec: The kv cache spec of each attention layer in the model Returns: True if all layers have the same type, False otherwise. 
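
For reference, `is_kv_cache_type_uniform` itself only gains the per-layer dict annotation in this patch; below is a minimal sketch of the check it performs, assuming `KVCacheSpec.type_id` identifies a spec's cache type (illustrative only, not the code touched by this hunk):

    from typing import Dict

    from vllm.v1.kv_cache_interface import KVCacheSpec

    def is_kv_cache_type_uniform_sketch(
            kv_cache_spec: Dict[str, KVCacheSpec]) -> bool:
        # "Uniform" means every attention layer reports the same type_id,
        # e.g. all layers are full attention with the same page size.
        type_ids = {spec.type_id for spec in kv_cache_spec.values()}
        return len(type_ids) == 1
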
@@ -487,18 +517,16 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, - kv_cache_spec: KVCacheSpec, - available_memory: int, - num_layers: int) -> KVCacheConfig: + kv_cache_spec: Dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model with one type of KV cache. Divide the available memory equally among all layers. Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_spec: The kv cache spec of each attention layer in the model available_memory: Memory available for KV cache in bytes. - num_layers: The number of layers in the model. Returns: The generated KVCacheConfig @@ -508,7 +536,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, assert len(page_sizes) == 1 page_size = page_sizes.pop() - num_blocks = int(available_memory // page_size // num_layers) + num_blocks = int(available_memory // page_size // len(kv_cache_spec)) num_blocks = max(num_blocks, 0) if vllm_config.cache_config.num_gpu_blocks_override is not None: @@ -528,6 +556,8 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, max_model_len_str, max_concurrency) per_layer_size = page_size * num_blocks + # All layers can be represented by the same virtual layer. + virtual_layer_map = [[layer_name for layer_name in kv_cache_spec]] kv_cache_config = KVCacheConfig( num_blocks=num_blocks, @@ -535,41 +565,68 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, layer_name: KVCacheTensor(size=per_layer_size) for layer_name in kv_cache_spec }, - groups=[[layer_name for layer_name in kv_cache_spec]], - kv_cache_spec=kv_cache_spec) + virtual_layers=create_virtual_layer(kv_cache_spec, virtual_layer_map), + ) return kv_cache_config -def get_kv_cache_configs(vllm_config: VllmConfig, - kv_cache_specs: List[KVCacheSpec], - available_memory: int) -> List[KVCacheConfig]: +def get_kv_cache_config(vllm_config: VllmConfig, + kv_cache_spec: Dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model TODO: support hybrid models with more than one type of KV cache. Args: vllm_config: The global VllmConfig - kv_cache_specs: The kv cache specs of the model + kv_cache_spec: The kv cache specs of the model available_memory: Memory available for KV cache in bytes. Returns: The generated KVCacheConfigs """ - # Use the max number of layers to conservatively determine - # the number of blocks. - num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs) - kv_cache_configs = [] - for kv_cache_spec in kv_cache_specs: - check_enough_kv_cache_memory(vllm_config, kv_cache_spec, - available_memory) - if is_kv_cache_type_uniform(kv_cache_spec): - # KV cache of all layers are the same, which is true for - # most models. Allocate the same amount of memory for - # each layer. - kv_cache_configs.append( - _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, - available_memory, - num_layers)) - else: - raise NotImplementedError + check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) + if is_kv_cache_type_uniform(kv_cache_spec): + # KV cache of all layers are the same, which is true for + # most models. Allocate the same amount of memory for + # each layer. 
+ return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, + available_memory) + + raise NotImplementedError + + +def make_kv_cache_configs_consistent(kv_cache_configs: List[KVCacheConfig]): + """ + Make the KV cache configurations for each worker consistent. + This function verifies that the virtual layers of each worker are the same, + and change the num_blocks of each worker to the smallest among all workers. + + Args: + kv_cache_configs (List[KVCacheConfig]): + The KV cache configurations for each worker. Will be in-place + modified to make them consistent. + """ + + # Sort the virtual layers by the type_id of the KV cache spec. + # This can avoid the inconsistency caused by the order of virtual layers. + for kv_cache_config in kv_cache_configs: + kv_cache_config.virtual_layers.sort( + key=lambda x: x.kv_cache_spec.type_id) + + # Verify that the virtual layers of each rank are the same. + for kv_cache_config in kv_cache_configs[1:]: + for virtual_layer1, virtual_layer2 in zip( + kv_cache_configs[0].virtual_layers, + kv_cache_config.virtual_layers): + assert virtual_layer1.kv_cache_spec == virtual_layer2.kv_cache_spec + + # Change the num_blocks of each rank to the smallest among all ranks. We + # do not need to shrink the tensor size because it is valid to only use the + # first `num_blocks` blocks of the tensor. + num_blocks = min(kv_cache_config.num_blocks + for kv_cache_config in kv_cache_configs) + for kv_cache_config in kv_cache_configs: + kv_cache_config.num_blocks = num_blocks + return kv_cache_configs diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 041896f1c7c..ae5f706508c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -20,7 +20,8 @@ from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.utils import get_exception_traceback, zmq_socket_ctx -from vllm.v1.core.kv_cache_utils import get_kv_cache_configs +from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, + make_kv_cache_configs_consistent) from vllm.v1.core.scheduler import Scheduler, SchedulerOutput from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) @@ -99,14 +100,20 @@ def _initialize_kv_caches(self, available_gpu_memory = self.model_executor.determine_available_memory() # Get the kv cache tensor size - kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs, - available_gpu_memory) - num_gpu_blocks_set = set(config.num_blocks - for config in kv_cache_configs) - assert len(num_gpu_blocks_set) == 1, ( - f"num_gpu_blocks need to be the same across workers, " - f"but they are different: {num_gpu_blocks_set}") - num_gpu_blocks = num_gpu_blocks_set.pop() + kv_cache_configs = [ + get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, + available_gpu_memory_one_worker) + for kv_cache_spec_one_worker, available_gpu_memory_one_worker in + zip(kv_cache_specs, available_gpu_memory) + ] + + # Since we use a shared centralized controller, we need the + # `kv_cache_config` to be consistent across all workers to make sure + # all the memory operators can be applied to all workers. 
+ make_kv_cache_configs_consistent(kv_cache_configs) + + # The kv cache configs + num_gpu_blocks = kv_cache_configs[0].num_blocks num_cpu_blocks = 0 # Initialize kv cache and warmup the execution diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 11002ad0022..66ba8df9a81 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from concurrent.futures import Future -from typing import List, Type, Union +from typing import Dict, List, Type, Union import torch import torch.distributed as dist @@ -62,14 +62,11 @@ def initialize_from_config(self, args=(kv_cache_configs, )) self.collective_rpc("compile_or_warm_up_model") - def determine_available_memory(self) -> int: # in bytes + def determine_available_memory(self) -> List[int]: # in bytes output = self.collective_rpc("determine_available_memory") - # Since we use a shared centralized controller, we take the minimum - # memory size across all workers to make sure all the memory - # operators can be applied to all workers. - return min(output) + return output - def get_kv_cache_specs(self) -> List[KVCacheSpec]: + def get_kv_cache_specs(self) -> List[Dict[str, KVCacheSpec]]: output = self.collective_rpc("get_kv_cache_spec") return output @@ -95,7 +92,7 @@ class UniProcExecutor(UniProcExecutorV0, Executor): class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): - def determine_available_memory(self) -> int: # in bytes + def determine_available_memory(self) -> List[int]: # in bytes # same as determine_num_available_blocks in v0, # we need to get the min across all ranks. memory = super().determine_available_memory() @@ -103,4 +100,4 @@ def determine_available_memory(self) -> int: # in bytes cpu_group = get_world_group().cpu_group memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) - return memory_tensor.item() + return [memory_tensor.item()] diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index eddfb5949eb..c6ca7dca784 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -12,7 +12,7 @@ @dataclass -class KVCacheSpecBase: +class KVCacheSpec: """ A base class for specifying the KV cache format of one layer. """ @@ -56,7 +56,7 @@ def bytes_for_tokens(self, num_tokens: int) -> int: @dataclass -class FullAttentionSpec(KVCacheSpecBase): +class FullAttentionSpec(KVCacheSpec): num_kv_heads: int head_size: int dtype: torch.dtype @@ -74,9 +74,6 @@ def bytes_for_tokens(self, num_tokens: int) -> int: return cdiv(num_tokens, self.block_size) * self.page_size_bytes -KVCacheSpec = Dict[str, KVCacheSpecBase] - - @dataclass class KVCacheTensor: """ @@ -87,6 +84,18 @@ class KVCacheTensor: size: int # The size of KV cache Tensor in bytes +@dataclass +class VirtualLayer: + """ + A dataclass for specifying a virtual layer, which represents multiple layers + that can share the same block_table. + """ + # The names of layers represented by this virtual layer + layer_names: List[str] + # The KV cache spec of this virtual layer + kv_cache_spec: KVCacheSpec + + @dataclass class KVCacheConfig: """ @@ -97,17 +106,20 @@ class KVCacheConfig: """layer_name -> how to initialize KV cache for that layer""" tensors: Dict[str, KVCacheTensor] """ - A list of kv-cache groups. 
Each group includes a set of layers with - the same kv-cache spec, and the total page_size of layers inside a group - is same across all groups (as the KVCacheManager only supports allocating - pages of the same size). For example: - 1. A model only uses full attention: one group with all layers in the model. - 2. (not implemented yet) A model with the same number of full attention - layers and sliding window attention layers: two groups, one for full - attention layers and one for sliding window attention layers. - 3. (not implemented yet) A model with 2 full attention layers and 4 sliding - window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). + The virtual_layers of the model. + The layers in the models are repeated with some patterns, e.g., a model + with 10 full attention layers and 20 sliding window attention layers can be + regarded as repeating the pattern (1 * full, 2 * sw) 10 times. And we regard + this pattern as virtual layers (3 virtual layers in this case, each + representing 10 layers). + The KVCacheManager allocates the blocks for each virtual layer, and the + model runner applies the block table of the virtual layer to all layers + represented by it. + For example: + 1. A model only uses full attention. There is only one virtual layer, + and the block table is shared by all layers. + 2. (WIP) A model with 10 full attention layers and 20 sliding window + attention. There are 3 virtual layers (1 * full, 2 * sw), and the block + table of each virtual layer is shared by 10 layers of the same type. """ - groups: List[List[str]] - """the KVCacheSpec of the model""" - kv_cache_spec: KVCacheSpec + virtual_layers: List[VirtualLayer] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0215b273538..d407245cbf7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1360,34 +1360,37 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.groups) > 1: + if len(kv_cache_config.virtual_layers) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") kv_caches: Dict[str, torch.Tensor] = {} - for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % layer_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // layer_spec.page_size_bytes - if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, - layer_spec.head_size) - dtype = layer_spec.dtype - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - else: - raise NotImplementedError + for virtual_layer in kv_cache_config.virtual_layers: + kv_cache_spec = virtual_layer.kv_cache_spec + for layer_name in virtual_layer.layer_names: + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes + assert num_blocks >= kv_cache_config.num_blocks + if isinstance(kv_cache_spec, FullAttentionSpec): + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + kv_caches[layer_name] = 
torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + else: + raise NotImplementedError bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, self.kv_caches) - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. @@ -1398,7 +1401,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size - kv_cache_spec: KVCacheSpec = {} + kv_cache_spec: Dict[str, KVCacheSpec] = {} for layer_name, attn_module in forward_ctx.items(): # TODO: Support other attention modules, e.g., sliding window, # cross-attention, MLA. diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index f681925f557..62b55576750 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -2,7 +2,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional, Set +from typing import TYPE_CHECKING, Dict, Optional, Set import torch import torch.distributed @@ -182,7 +182,7 @@ def determine_available_memory(self) -> int: return int(available_kv_cache_memory) - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2c6a0371cde..3af2f53ead4 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -272,7 +272,7 @@ def get_model(self) -> nn.Module: assert self.model is not None return self.model - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. @@ -283,7 +283,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size - kv_cache_spec: KVCacheSpec = {} + kv_cache_spec: Dict[str, KVCacheSpec] = {} for layer_name, attn_module in forward_ctx.items(): # TODO: Support other attention modules, e.g., sliding window, # cross-attention, MLA. 
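
The signature change above reflects the main interface change in this patch: `get_kv_cache_spec` now returns an explicit `Dict[str, KVCacheSpec]` rather than relying on the old `KVCacheSpec` dict alias. A sketch of the kind of value such a method might produce (layer names and sizes are illustrative, not taken from a real model):

    from typing import Dict

    import torch

    from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec

    def example_kv_cache_spec() -> Dict[str, KVCacheSpec]:
        # One FullAttentionSpec per attention layer, keyed by layer name.
        return {
            f"model.layers.{i}.self_attn.attn": FullAttentionSpec(
                block_size=16,
                num_kv_heads=8,
                head_size=128,
                dtype=torch.bfloat16)
            for i in range(2)
        }
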
@@ -607,31 +607,33 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.groups) > 1: + if len(kv_cache_config.virtual_layers) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") kv_caches: Dict[str, torch.Tensor] = {} - for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % layer_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // layer_spec.page_size_bytes - if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, - layer_spec.head_size) - dtype = layer_spec.dtype - - tpu_k_cache = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - tpu_v_cache = torch.zeros_like(tpu_k_cache) - - kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) - else: - raise NotImplementedError + for virtual_layer in kv_cache_config.virtual_layers: + kv_cache_spec = virtual_layer.kv_cache_spec + for layer_name in virtual_layer.layer_names: + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes + if isinstance(kv_cache_spec, FullAttentionSpec): + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + + tpu_k_cache = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + tpu_v_cache = torch.zeros_like(tpu_k_cache) + + kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) + else: + raise NotImplementedError bind_kv_cache( kv_caches, diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 405dc628ee1..09742065c1e 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -166,7 +166,7 @@ def compile_or_warm_up_model(self) -> None: def get_model(self) -> nn.Module: return self.model_runner.get_model() - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 51d2da2344b..a45b45637e8 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +from typing import Dict, Optional import torch import torch.nn as nn @@ -51,7 +51,7 @@ def __init__( self.device: Optional[torch.device] = None self.model_runner: Optional[nn.Module] = None - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]: """Get specifications for KV cache implementation.""" raise NotImplementedError From 2b30e352a5cbf509ef8e8e9e3b2b0fc2f736f33b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 1 Mar 2025 05:22:09 -0800 Subject: [PATCH 2/8] update comments Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 12 ++++++------ vllm/v1/engine/core.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 62a9040c502..88abaf15006 100644 --- 
a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -579,7 +579,7 @@ def get_kv_cache_config(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache specs of the model + kv_cache_spec: The kv cache spec of each attention layer in the model available_memory: Memory available for KV cache in bytes. Returns: @@ -598,14 +598,14 @@ def get_kv_cache_config(vllm_config: VllmConfig, def make_kv_cache_configs_consistent(kv_cache_configs: List[KVCacheConfig]): """ - Make the KV cache configurations for each worker consistent. + Make the KV cache configurations for each worker consistent, so that all + workers can be controlled by the same KVCacheManager. This function verifies that the virtual layers of each worker are the same, - and change the num_blocks of each worker to the smallest among all workers. + and changes the num_blocks of each worker to the smallest among all workers. Args: - kv_cache_configs (List[KVCacheConfig]): - The KV cache configurations for each worker. Will be in-place - modified to make them consistent. + kv_cache_configs: The KV cache configurations for each worker. Will be + in-place modified to make them consistent. """ # Sort the virtual layers by the type_id of the KV cache spec. diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ae5f706508c..affac324311 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -112,7 +112,8 @@ def _initialize_kv_caches(self, # all the memory operators can be applied to all workers. make_kv_cache_configs_consistent(kv_cache_configs) - # The kv cache configs + # All workers have the same kv_cache_config except layer names, so use + # an arbitrary one to get the number of blocks. num_gpu_blocks = kv_cache_configs[0].num_blocks num_cpu_blocks = 0 From 93adab8472d531d844a5f469347456feb5655daf Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 8 Mar 2025 22:25:48 -0800 Subject: [PATCH 3/8] address review comments Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 88abaf15006..778e02a56d4 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -557,7 +557,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, per_layer_size = page_size * num_blocks # All layers can be represented by the same virtual layer. - virtual_layer_map = [[layer_name for layer_name in kv_cache_spec]] + virtual_layer_map = [list(kv_cache_spec.keys())] kv_cache_config = KVCacheConfig( num_blocks=num_blocks, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d407245cbf7..d048e05870e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1373,6 +1373,13 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: tensor_config = kv_cache_config.tensors[layer_name] assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes + # `num_blocks` is the number of blocks the model runner can use. + # `kv_cache_config.num_blocks` is the number of blocks that + # KVCacheManager may allocate. 
+ # Since different GPUs may have different number of layers and + # different memory capacities, `num_blocks` can be different on + # different GPUs, and `kv_cache_config.num_blocks` is set to + # the min of all `num_blocks`. Verify it here. assert num_blocks >= kv_cache_config.num_blocks if isinstance(kv_cache_spec, FullAttentionSpec): kv_cache_shape = self.attn_backend.get_kv_cache_shape( From 4d056261d411d12baa74f1b3853debbfe033775a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 9 Mar 2025 09:24:34 -0700 Subject: [PATCH 4/8] ManagerKVLayer Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 22 +++++++++---------- vllm/v1/kv_cache_interface.py | 34 +++++++++++++++++------------- vllm/v1/worker/gpu_model_runner.py | 4 ++-- vllm/v1/worker/tpu_model_runner.py | 4 ++-- 4 files changed, 34 insertions(+), 30 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 37c8a2e3240..9a2e228b2e3 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -8,7 +8,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, - KVCacheTensor, VirtualLayer) + KVCacheTensor, ManagerKVLayer) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -486,9 +486,9 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, def create_virtual_layer( kv_cache_spec: dict[str, KVCacheSpec], - virtual_layer_map: list[list[str]]) -> list[VirtualLayer]: + virtual_layer_map: list[list[str]]) -> list[ManagerKVLayer]: """ - Create VirtualLayer object for each virtual layer. + Create ManagerKVLayer object for each virtual layer. The layers represented by the same virtual layer should share the same KVCacheSpec. @@ -500,9 +500,9 @@ def create_virtual_layer( names that represented by the same virtual layer and should share the same KVCacheSpec. Returns: - A list of VirtualLayer objects, one for each virtual layer. + A list of ManagerKVLayer objects, one for each virtual layer. """ - virtual_layers = [] + manager_layers = [] for layer_names in virtual_layer_map: layer_spec = kv_cache_spec[layer_names[0]] assert all( @@ -510,8 +510,8 @@ def create_virtual_layer( for layer_name in layer_names[1:] ), ("All layers represented by one virtual layer must share the same " "KVCacheSpec.") - virtual_layers.append(VirtualLayer(layer_names, layer_spec)) - return virtual_layers + manager_layers.append(ManagerKVLayer(layer_names, layer_spec)) + return manager_layers def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: @@ -578,7 +578,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, layer_name: KVCacheTensor(size=per_layer_size) for layer_name in kv_cache_spec }, - virtual_layers=create_virtual_layer(kv_cache_spec, virtual_layer_map), + manager_layers=create_virtual_layer(kv_cache_spec, virtual_layer_map), ) return kv_cache_config @@ -624,14 +624,14 @@ def make_kv_cache_configs_consistent(kv_cache_configs: list[KVCacheConfig]): # Sort the virtual layers by the type_id of the KV cache spec. # This can avoid the inconsistency caused by the order of virtual layers. for kv_cache_config in kv_cache_configs: - kv_cache_config.virtual_layers.sort( + kv_cache_config.manager_layers.sort( key=lambda x: x.kv_cache_spec.type_id) # Verify that the virtual layers of each rank are the same. 
for kv_cache_config in kv_cache_configs[1:]: for virtual_layer1, virtual_layer2 in zip( - kv_cache_configs[0].virtual_layers, - kv_cache_config.virtual_layers): + kv_cache_configs[0].manager_layers, + kv_cache_config.manager_layers): assert virtual_layer1.kv_cache_spec == virtual_layer2.kv_cache_spec # Change the num_blocks of each rank to the smallest among all ranks. We diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index d5498c18506..e3a67f686b5 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -87,14 +87,14 @@ class KVCacheTensor: @dataclass -class VirtualLayer: +class ManagerKVLayer: """ - A dataclass for specifying a virtual layer, which represents multiple layers - that can share the same block_table. + Represents a set of model layers that share the same KV cache block table. + These layers are regarded as one layer in the KV cache manager. """ - # The names of layers represented by this virtual layer + # The names of model layers represented by this manager layer layer_names: list[str] - # The KV cache spec of this virtual layer + # The KV cache spec of this manager layer kv_cache_spec: KVCacheSpec @@ -108,20 +108,24 @@ class KVCacheConfig: """layer_name -> how to initialize KV cache for that layer""" tensors: dict[str, KVCacheTensor] """ - The virtual_layers of the model. + The manager_layers of the model. The layers in the models are repeated with some patterns, e.g., a model with 10 full attention layers and 20 sliding window attention layers can be - regarded as repeating the pattern (1 * full, 2 * sw) 10 times. And we regard - this pattern as virtual layers (3 virtual layers in this case, each - representing 10 layers). - The KVCacheManager allocates the blocks for each virtual layer, and the - model runner applies the block table of the virtual layer to all layers + regarded as repeating the pattern (1 * full, 2 * sw) 10 times. + The KVCacheManager allocate different block tables for each of the 3 layers + in the pattern, and repeat each of them 10 times to generate the + block_table for the 30 layers in the model. + From the view of KVCacheManager, there are only 3 layers, so we call the 3 + layers in the pattern "manager layers". + + The KVCacheManager allocates the blocks for each manager layer, and the + model runner applies the block table of the manager layer to all layers represented by it. For example: - 1. A model only uses full attention. There is only one virtual layer, + 1. A model only uses full attention. There is only one manager layer, and the block table is shared by all layers. 2. (WIP) A model with 10 full attention layers and 20 sliding window - attention. There are 3 virtual layers (1 * full, 2 * sw), and the block - table of each virtual layer is shared by 10 layers of the same type. + attention. There are 3 manager layers (1 * full, 2 * sw), and the block + table of each manager layer is shared by 10 layers of the same type. 
""" - virtual_layers: list[VirtualLayer] + manager_layers: list[ManagerKVLayer] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 402a1fbe799..e1cb2b2fee7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1423,14 +1423,14 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.virtual_layers) > 1: + if len(kv_cache_config.manager_layers) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") kv_caches: dict[str, torch.Tensor] = {} - for virtual_layer in kv_cache_config.virtual_layers: + for virtual_layer in kv_cache_config.manager_layers: kv_cache_spec = virtual_layer.kv_cache_spec for layer_name in virtual_layer.layer_names: tensor_config = kv_cache_config.tensors[layer_name] diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index e45da730acf..38be49f1716 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -758,14 +758,14 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.virtual_layers) > 1: + if len(kv_cache_config.manager_layers) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") kv_caches: dict[str, torch.Tensor] = {} - for virtual_layer in kv_cache_config.virtual_layers: + for virtual_layer in kv_cache_config.manager_layers: kv_cache_spec = virtual_layer.kv_cache_spec for layer_name in virtual_layer.layer_names: tensor_config = kv_cache_config.tensors[layer_name] From 530d4bffa7e5be7409584ff3fadcc33a8deb9783 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 18 Mar 2025 06:12:54 -0700 Subject: [PATCH 5/8] update names Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 54 +++++++++++--------- vllm/v1/core/kv_cache_utils.py | 73 +++++++++++++++------------- vllm/v1/engine/core.py | 9 +++- vllm/v1/kv_cache_interface.py | 13 +++-- vllm/v1/worker/gpu_model_runner.py | 8 +-- vllm/v1/worker/tpu_model_runner.py | 8 +-- 6 files changed, 89 insertions(+), 76 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index ab7486bf01e..283f8369d50 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -10,9 +10,9 @@ generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens, - make_kv_cache_configs_consistent) + unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheTensor, VirtualLayer) + KVCacheGroupSpec, KVCacheTensor) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -320,7 +320,7 @@ def stats(requests, queries, hits): assert not metrics.query_queue -def test_make_kv_cache_configs_consistent(): +def test_unify_kv_cache_configs(): def new_kv_cache_spec(block_size=16, num_kv_heads=2, @@ -338,9 +338,10 @@ def new_kv_cache_spec(block_size=16, "layer1": KVCacheTensor(100), "layer2": KVCacheTensor(100), }, - virtual_layers=[ - VirtualLayer(["layer1"], new_kv_cache_spec()), - VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], + 
new_kv_cache_spec(num_kv_heads=4)), ], ), KVCacheConfig( @@ -349,13 +350,14 @@ def new_kv_cache_spec(block_size=16, "layer1": KVCacheTensor(100), "layer2": KVCacheTensor(100), }, - virtual_layers=[ - VirtualLayer(["layer1"], new_kv_cache_spec()), - VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=4)), ], ), ] - make_kv_cache_configs_consistent(same_kv_cache_config) + unify_kv_cache_configs(same_kv_cache_config) assert same_kv_cache_config[0].num_blocks == 10 assert same_kv_cache_config[1].num_blocks == 10 @@ -366,9 +368,10 @@ def new_kv_cache_spec(block_size=16, "layer1": KVCacheTensor(100), "layer2": KVCacheTensor(100), }, - virtual_layers=[ - VirtualLayer(["layer1"], new_kv_cache_spec()), - VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=4)), ], ), KVCacheConfig( @@ -377,14 +380,15 @@ def new_kv_cache_spec(block_size=16, "layer1": KVCacheTensor(100), "layer2": KVCacheTensor(100), }, - virtual_layers=[ - VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=4)), - VirtualLayer(["layer1"], new_kv_cache_spec()), + kv_cache_groups=[ + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=4)), + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), ], ), ] - make_kv_cache_configs_consistent(need_sort_kv_cache_config) + unify_kv_cache_configs(need_sort_kv_cache_config) assert need_sort_kv_cache_config[0].num_blocks == 10 assert need_sort_kv_cache_config[1].num_blocks == 10 @@ -395,9 +399,10 @@ def new_kv_cache_spec(block_size=16, "layer1": KVCacheTensor(100), "layer2": KVCacheTensor(100), }, - virtual_layers=[ - VirtualLayer(["layer1"], new_kv_cache_spec()), - VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=4)), + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=4)), ], ), KVCacheConfig( @@ -406,11 +411,12 @@ def new_kv_cache_spec(block_size=16, "layer1": KVCacheTensor(100), "layer2": KVCacheTensor(100), }, - virtual_layers=[ - VirtualLayer(["layer1"], new_kv_cache_spec()), - VirtualLayer(["layer2"], new_kv_cache_spec(num_kv_heads=8)), + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=8)), ], ), ] with pytest.raises(AssertionError): - make_kv_cache_configs_consistent(diff_kv_cache_config) + unify_kv_cache_configs(diff_kv_cache_config) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9a2e228b2e3..600d5a66de4 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -7,8 +7,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, - KVCacheTensor, ManagerKVLayer) +from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec, + KVCacheSpec, KVCacheTensor) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -484,34 +484,35 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, f"`max_model_len` when initializing the engine.") -def create_virtual_layer( +def create_kv_cache_group_spec( kv_cache_spec: dict[str, KVCacheSpec], - virtual_layer_map: list[list[str]]) -> list[ManagerKVLayer]: + grouped_layer_names: 
list[list[str]]) -> list[KVCacheGroupSpec]: """ - Create ManagerKVLayer object for each virtual layer. - The layers represented by the same virtual layer should share the same + Create KVCacheGroupSpec object for each kv cache group layer. + The layers in the same group should share the same KVCacheSpec. Args: kv_cache_spec: A mapping from each layer name to its corresponding KVCacheSpec. - virtual_layer_map: - A list of virtual layers, where each element is a list of layer - names that represented by the same virtual layer and should share - the same KVCacheSpec. + grouped_layer_names: + A list of kv cache groups, where each element is a list of layer + names that belong to the same group and should share the same + KVCacheSpec. Returns: - A list of ManagerKVLayer objects, one for each virtual layer. + A list of KVCacheGroupSpec objects, one for each group. """ - manager_layers = [] - for layer_names in virtual_layer_map: - layer_spec = kv_cache_spec[layer_names[0]] + kv_cache_groups = [] + for layer_names_one_group in grouped_layer_names: + layer_spec = kv_cache_spec[layer_names_one_group[0]] assert all( kv_cache_spec[layer_name] == layer_spec - for layer_name in layer_names[1:] - ), ("All layers represented by one virtual layer must share the same " - "KVCacheSpec.") - manager_layers.append(ManagerKVLayer(layer_names, layer_spec)) - return manager_layers + for layer_name in layer_names_one_group[1:]), ( + "All layers in the same KV cache group must share the same " + "KVCacheSpec.") + kv_cache_groups.append( + KVCacheGroupSpec(layer_names_one_group, layer_spec)) + return kv_cache_groups def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: @@ -569,8 +570,9 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, max_model_len_str, max_concurrency) per_layer_size = page_size * num_blocks - # All layers can be represented by the same virtual layer. - virtual_layer_map = [list(kv_cache_spec.keys())] + # All layers have the same KV cache spec, so we create one kv cache group + # for all layers. + grouped_layer_names = [list(kv_cache_spec.keys())] kv_cache_config = KVCacheConfig( num_blocks=num_blocks, @@ -578,7 +580,8 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, layer_name: KVCacheTensor(size=per_layer_size) for layer_name in kv_cache_spec }, - manager_layers=create_virtual_layer(kv_cache_spec, virtual_layer_map), + kv_cache_groups=create_kv_cache_group_spec(kv_cache_spec, + grouped_layer_names), ) return kv_cache_config @@ -609,11 +612,11 @@ def get_kv_cache_config(vllm_config: VllmConfig, raise NotImplementedError -def make_kv_cache_configs_consistent(kv_cache_configs: list[KVCacheConfig]): +def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): """ Make the KV cache configurations for each worker consistent, so that all workers can be controlled by the same KVCacheManager. - This function verifies that the virtual layers of each worker are the same, + This function verifies that the layer group of each worker are the same, and changes the num_blocks of each worker to the smallest among all workers. Args: @@ -621,25 +624,25 @@ def make_kv_cache_configs_consistent(kv_cache_configs: list[KVCacheConfig]): in-place modified to make them consistent. """ - # Sort the virtual layers by the type_id of the KV cache spec. - # This can avoid the inconsistency caused by the order of virtual layers. + # Sort the kv cache groups by the type_id of their KV cache spec. + # This can avoid the inconsistency caused by the order of groups. 
for kv_cache_config in kv_cache_configs: - kv_cache_config.manager_layers.sort( + kv_cache_config.kv_cache_groups.sort( key=lambda x: x.kv_cache_spec.type_id) - # Verify that the virtual layers of each rank are the same. + # Verify that the groups of each rank are the same. for kv_cache_config in kv_cache_configs[1:]: - for virtual_layer1, virtual_layer2 in zip( - kv_cache_configs[0].manager_layers, - kv_cache_config.manager_layers): - assert virtual_layer1.kv_cache_spec == virtual_layer2.kv_cache_spec + for group_rank_0, group_rank_i in zip( + kv_cache_configs[0].kv_cache_groups, + kv_cache_config.kv_cache_groups): + assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec # Change the num_blocks of each rank to the smallest among all ranks. We # do not need to shrink the tensor size because it is valid to only use the # first `num_blocks` blocks of the tensor. - num_blocks = min(kv_cache_config.num_blocks - for kv_cache_config in kv_cache_configs) + min_num_blocks = min(kv_cache_config.num_blocks + for kv_cache_config in kv_cache_configs) for kv_cache_config in kv_cache_configs: - kv_cache_config.num_blocks = num_blocks + kv_cache_config.num_blocks = min_num_blocks return kv_cache_configs diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b4a4a52b5b3..388822d06df 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -21,7 +21,7 @@ maybe_register_config_serialize_by_value) from vllm.utils import get_exception_traceback, zmq_socket_ctx from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, - make_kv_cache_configs_consistent) + unify_kv_cache_configs) from vllm.v1.core.scheduler import Scheduler, SchedulerOutput from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) @@ -103,6 +103,7 @@ def _initialize_kv_caches(self, # memory can be allocated for kv cache. available_gpu_memory = self.model_executor.determine_available_memory() + assert len(kv_cache_specs) == len(available_gpu_memory) # Get the kv cache tensor size kv_cache_configs = [ get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, @@ -114,10 +115,14 @@ def _initialize_kv_caches(self, # Since we use a shared centralized controller, we need the # `kv_cache_config` to be consistent across all workers to make sure # all the memory operators can be applied to all workers. - make_kv_cache_configs_consistent(kv_cache_configs) + unify_kv_cache_configs(kv_cache_configs) # All workers have the same kv_cache_config except layer names, so use # an arbitrary one to get the number of blocks. + assert all([ + cfg.num_blocks == kv_cache_configs[0].num_blocks + for cfg in kv_cache_configs + ]) num_gpu_blocks = kv_cache_configs[0].num_blocks num_cpu_blocks = 0 diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index e3a67f686b5..15c8c026d30 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -87,12 +87,12 @@ class KVCacheTensor: @dataclass -class ManagerKVLayer: +class KVCacheGroupSpec: """ - Represents a set of model layers that share the same KV cache block table. + Represents a group of model layers that share the same KV cache block table. These layers are regarded as one layer in the KV cache manager. 
""" - # The names of model layers represented by this manager layer + # The names of model layers in this group layer_names: list[str] # The KV cache spec of this manager layer kv_cache_spec: KVCacheSpec @@ -108,15 +108,14 @@ class KVCacheConfig: """layer_name -> how to initialize KV cache for that layer""" tensors: dict[str, KVCacheTensor] """ - The manager_layers of the model. + The kv cache groups of the model. The layers in the models are repeated with some patterns, e.g., a model with 10 full attention layers and 20 sliding window attention layers can be regarded as repeating the pattern (1 * full, 2 * sw) 10 times. The KVCacheManager allocate different block tables for each of the 3 layers in the pattern, and repeat each of them 10 times to generate the block_table for the 30 layers in the model. - From the view of KVCacheManager, there are only 3 layers, so we call the 3 - layers in the pattern "manager layers". + From the view of KVCacheManager, there are only 3 layers. The KVCacheManager allocates the blocks for each manager layer, and the model runner applies the block table of the manager layer to all layers @@ -128,4 +127,4 @@ class KVCacheConfig: attention. There are 3 manager layers (1 * full, 2 * sw), and the block table of each manager layer is shared by 10 layers of the same type. """ - manager_layers: list[ManagerKVLayer] + kv_cache_groups: list[KVCacheGroupSpec] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e1cb2b2fee7..8035e7aba8c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1423,16 +1423,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.manager_layers) > 1: + if len(kv_cache_config.kv_cache_groups) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") kv_caches: dict[str, torch.Tensor] = {} - for virtual_layer in kv_cache_config.manager_layers: - kv_cache_spec = virtual_layer.kv_cache_spec - for layer_name in virtual_layer.layer_names: + for kv_cache_group in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group.kv_cache_spec + for layer_name in kv_cache_group.layer_names: tensor_config = kv_cache_config.tensors[layer_name] assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 38be49f1716..86622acdf4c 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -758,16 +758,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.manager_layers) > 1: + if len(kv_cache_config.kv_cache_groups) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") kv_caches: dict[str, torch.Tensor] = {} - for virtual_layer in kv_cache_config.manager_layers: - kv_cache_spec = virtual_layer.kv_cache_spec - for layer_name in virtual_layer.layer_names: + for kv_cache_group in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group.kv_cache_spec + for layer_name in kv_cache_group.layer_names: tensor_config = kv_cache_config.tensors[layer_name] assert tensor_config.size % kv_cache_spec.page_size_bytes 
== 0 num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes From 56e0b5db95837d5e6c088bb4b68acd04c06dfa3d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 18 Mar 2025 08:12:34 -0700 Subject: [PATCH 6/8] update comments Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d7f1c00a99d..35a85a579db 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1473,7 +1473,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: dtype=dtype, device=self.device) else: - raise NotImplementedError + # TODO: add new branches when introducing more types of + # KV cache specs. + raise ValueError("Unknown KV cache spec type.") bind_kv_cache( kv_caches, From 19b25897a30f6c2a2e76b9c4b247c5b60bb8872b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 18 Mar 2025 08:33:36 -0700 Subject: [PATCH 7/8] update the explaination of kv cache groups Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 6 +++--- vllm/v1/kv_cache_interface.py | 25 +++++++++++++------------ 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 600d5a66de4..e0d7f4dbdc1 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -484,7 +484,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, f"`max_model_len` when initializing the engine.") -def create_kv_cache_group_spec( +def create_kv_cache_group_specs( kv_cache_spec: dict[str, KVCacheSpec], grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: """ @@ -580,8 +580,8 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, layer_name: KVCacheTensor(size=per_layer_size) for layer_name in kv_cache_spec }, - kv_cache_groups=create_kv_cache_group_spec(kv_cache_spec, - grouped_layer_names), + kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, + grouped_layer_names), ) return kv_cache_config diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 15c8c026d30..867b1b61c87 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -112,19 +112,20 @@ class KVCacheConfig: The layers in the models are repeated with some patterns, e.g., a model with 10 full attention layers and 20 sliding window attention layers can be regarded as repeating the pattern (1 * full, 2 * sw) 10 times. - The KVCacheManager allocate different block tables for each of the 3 layers - in the pattern, and repeat each of them 10 times to generate the - block_table for the 30 layers in the model. - From the view of KVCacheManager, there are only 3 layers. - - The KVCacheManager allocates the blocks for each manager layer, and the - model runner applies the block table of the manager layer to all layers - represented by it. + The KVCacheManager allocates different block tables for each of the 3 layers + in the pattern, and repeats each of them 10 times to generate the + block_table for the 30 layers in the model. + Therefore, we can group the layers in the model into 3 groups, each of which + contains 10 layers in the model. + The KVCacheManager allocates the block_table for each group based on its + kv_cache spec, and the model runner applies the block table to each layer + in the group. For example: - 1. A model only uses full attention. 
There is only one manager layer, - and the block table is shared by all layers. + 1. A model only uses full attention. The pattern is + (num_hidden_layers * full), so there is only one group and the block table + is shared by all layers. 2. (WIP) A model with 10 full attention layers and 20 sliding window - attention. There are 3 manager layers (1 * full, 2 * sw), and the block - table of each manager layer is shared by 10 layers of the same type. + attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so + there are 3 groups, each of which represents 10 layers in the model. """ kv_cache_groups: list[KVCacheGroupSpec] From 757f350ca35defbdf88df04b223041eb79745cdb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 20 Mar 2025 19:30:44 -0700 Subject: [PATCH 8/8] fix tests Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 283f8369d50..3fecb517c43 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -325,11 +325,13 @@ def test_unify_kv_cache_configs(): def new_kv_cache_spec(block_size=16, num_kv_heads=2, head_size=64, - dtype=torch.float32): + dtype=torch.float32, + use_mla=False): return FullAttentionSpec(block_size=block_size, num_kv_heads=num_kv_heads, head_size=head_size, - dtype=dtype) + dtype=dtype, + use_mla=use_mla) same_kv_cache_config = [ KVCacheConfig(
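
To make the grouping described in the final `KVCacheConfig` docstring concrete, here is a hedged sketch of the `grouped_layer_names` that `create_kv_cache_group_specs` would receive for a hypothetical hybrid model with the pattern (1 * full, 2 * sw) repeated twice (hybrid support is still marked WIP in this series, and the layer names are made up):

    # Pattern (1 * full, 2 * sw) repeated twice -> 3 kv cache groups, each
    # holding the layers that sit at the same position in the pattern.
    grouped_layer_names = [
        ["layers.0.attn", "layers.3.attn"],  # full-attention layer of each repeat
        ["layers.1.attn", "layers.4.attn"],  # first sliding-window layer
        ["layers.2.attn", "layers.5.attn"],  # second sliding-window layer
    ]
    # create_kv_cache_group_specs(kv_cache_spec, grouped_layer_names) would then
    # return three KVCacheGroupSpec objects, and the KVCacheManager would keep
    # one block table per group, shared by both layers in that group.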