diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index ba08b83ec54..3fecb517c43 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest +import torch from vllm.multimodal.inputs import MultiModalKwargs from vllm.sampling_params import SamplingParams @@ -8,7 +9,10 @@ KVCacheBlock, PrefixCachingMetrics, generate_block_hash_extra_keys, hash_block_tokens, - hash_request_tokens) + hash_request_tokens, + unify_kv_cache_configs) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheTensor) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -314,3 +318,107 @@ def stats(requests, queries, hits): assert metrics.aggregated_query_total == 0 assert metrics.aggregated_query_hit == 0 assert not metrics.query_queue + + +def test_unify_kv_cache_configs(): + + def new_kv_cache_spec(block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float32, + use_mla=False): + return FullAttentionSpec(block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + use_mla=use_mla) + + same_kv_cache_config = [ + KVCacheConfig( + num_blocks=10, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=4)), + ], + ), + KVCacheConfig( + num_blocks=20, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=4)), + ], + ), + ] + unify_kv_cache_configs(same_kv_cache_config) + assert same_kv_cache_config[0].num_blocks == 10 + assert same_kv_cache_config[1].num_blocks == 10 + + need_sort_kv_cache_config = [ + KVCacheConfig( + num_blocks=10, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=4)), + ], + ), + KVCacheConfig( + num_blocks=20, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + kv_cache_groups=[ + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=4)), + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + ], + ), + ] + + unify_kv_cache_configs(need_sort_kv_cache_config) + assert need_sort_kv_cache_config[0].num_blocks == 10 + assert need_sort_kv_cache_config[1].num_blocks == 10 + + diff_kv_cache_config = [ + KVCacheConfig( + num_blocks=10, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=4)), + ], + ), + KVCacheConfig( + num_blocks=20, + tensors={ + "layer1": KVCacheTensor(100), + "layer2": KVCacheTensor(100), + }, + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer2"], + new_kv_cache_spec(num_kv_heads=8)), + ], + ), + ] + with pytest.raises(AssertionError): + unify_kv_cache_configs(diff_kv_cache_config) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index adadcab5ea1..e0d7f4dbdc1 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -7,8 +7,8 @@ from 
vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, - KVCacheTensor) +from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec, + KVCacheSpec, KVCacheTensor) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -449,7 +449,7 @@ def hash_request_tokens(block_size: int, def check_enough_kv_cache_memory(vllm_config: VllmConfig, - kv_cache_spec: KVCacheSpec, + kv_cache_spec: dict[str, KVCacheSpec], available_memory: int): """ Checks whether `available_memory` is enough for the KV cache to hold at @@ -457,7 +457,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_spec: The kv cache spec of each attention layer in the model available_memory: Memory available for KV cache in bytes. Raises: @@ -484,12 +484,43 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, f"`max_model_len` when initializing the engine.") -def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: +def create_kv_cache_group_specs( + kv_cache_spec: dict[str, KVCacheSpec], + grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: + """ + Create a KVCacheGroupSpec for each kv cache group. + The layers in the same group should share the same + KVCacheSpec. + + Args: + kv_cache_spec: + A mapping from each layer name to its corresponding KVCacheSpec. + grouped_layer_names: + A list of kv cache groups, where each element is a list of layer + names that belong to the same group and should share the same + KVCacheSpec. + Returns: + A list of KVCacheGroupSpec objects, one for each group. + """ + kv_cache_groups = [] + for layer_names_one_group in grouped_layer_names: + layer_spec = kv_cache_spec[layer_names_one_group[0]] + assert all( + kv_cache_spec[layer_name] == layer_spec + for layer_name in layer_names_one_group[1:]), ( + "All layers in the same KV cache group must share the same " + "KVCacheSpec.") + kv_cache_groups.append( + KVCacheGroupSpec(layer_names_one_group, layer_spec)) + return kv_cache_groups + + +def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same type of KV cache. Args: - kv_cache_spec: The KVCacheSpec of the model + kv_cache_spec: The kv cache spec of each attention layer in the model Returns: True if all layers have the same type, False otherwise. @@ -500,18 +531,16 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, - kv_cache_spec: KVCacheSpec, - available_memory: int, - num_layers: int) -> KVCacheConfig: + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model with one type of KV cache. Divide the available memory equally among all layers. Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_spec: The kv cache spec of each attention layer in the model available_memory: Memory available for KV cache in bytes. - num_layers: The number of layers in the model. 
Returns: The generated KVCacheConfig @@ -521,7 +550,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, assert len(page_sizes) == 1 page_size = page_sizes.pop() - num_blocks = int(available_memory // page_size // num_layers) + num_blocks = int(available_memory // page_size // len(kv_cache_spec)) num_blocks = max(num_blocks, 0) if vllm_config.cache_config.num_gpu_blocks_override is not None: @@ -541,6 +570,9 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, max_model_len_str, max_concurrency) per_layer_size = page_size * num_blocks + # All layers have the same KV cache spec, so we create one kv cache group + # for all layers. + grouped_layer_names = [list(kv_cache_spec.keys())] kv_cache_config = KVCacheConfig( num_blocks=num_blocks, @@ -548,41 +580,69 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, layer_name: KVCacheTensor(size=per_layer_size) for layer_name in kv_cache_spec }, - groups=[[layer_name for layer_name in kv_cache_spec]], - kv_cache_spec=kv_cache_spec) + kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, + grouped_layer_names), + ) return kv_cache_config -def get_kv_cache_configs(vllm_config: VllmConfig, - kv_cache_specs: list[KVCacheSpec], - available_memory: int) -> list[KVCacheConfig]: +def get_kv_cache_config(vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model TODO: support hybrid models with more than one type of KV cache. Args: vllm_config: The global VllmConfig - kv_cache_specs: The kv cache specs of the model + kv_cache_spec: The kv cache spec of each attention layer in the model available_memory: Memory available for KV cache in bytes. Returns: The generated KVCacheConfigs """ - # Use the max number of layers to conservatively determine - # the number of blocks. - num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs) - kv_cache_configs = [] - for kv_cache_spec in kv_cache_specs: - check_enough_kv_cache_memory(vllm_config, kv_cache_spec, - available_memory) - if is_kv_cache_type_uniform(kv_cache_spec): - # KV cache of all layers are the same, which is true for - # most models. Allocate the same amount of memory for - # each layer. - kv_cache_configs.append( - _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, - available_memory, - num_layers)) - else: - raise NotImplementedError + check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) + if is_kv_cache_type_uniform(kv_cache_spec): + # KV cache of all layers are the same, which is true for + # most models. Allocate the same amount of memory for + # each layer. + return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, + available_memory) + + raise NotImplementedError + + +def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): + """ + Make the KV cache configurations for each worker consistent, so that all + workers can be controlled by the same KVCacheManager. + This function verifies that the layer groups of each worker are the same, + and changes the num_blocks of each worker to the smallest among all workers. + + Args: + kv_cache_configs: The KV cache configurations for each worker. Will be + modified in place to make them consistent. + """ + + # Sort the kv cache groups by the type_id of their KV cache spec. + # This avoids inconsistency caused by the order of the groups. 
+ for kv_cache_config in kv_cache_configs: + kv_cache_config.kv_cache_groups.sort( + key=lambda x: x.kv_cache_spec.type_id) + + # Verify that the groups of each rank are the same. + for kv_cache_config in kv_cache_configs[1:]: + for group_rank_0, group_rank_i in zip( + kv_cache_configs[0].kv_cache_groups, + kv_cache_config.kv_cache_groups): + assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec + + # Change the num_blocks of each rank to the smallest among all ranks. We + # do not need to shrink the tensor size because it is valid to only use the + # first `num_blocks` blocks of the tensor. + min_num_blocks = min(kv_cache_config.num_blocks + for kv_cache_config in kv_cache_configs) + for kv_cache_config in kv_cache_configs: + kv_cache_config.num_blocks = min_num_blocks + return kv_cache_configs diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 1598e6b8443..f4bb4583bea 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -21,7 +21,8 @@ maybe_register_config_serialize_by_value) from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, zmq_socket_ctx) -from vllm.v1.core.kv_cache_utils import get_kv_cache_configs +from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, + unify_kv_cache_configs) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, @@ -120,15 +121,27 @@ def _initialize_kv_caches(self, # memory can be allocated for kv cache. available_gpu_memory = self.model_executor.determine_available_memory() + assert len(kv_cache_specs) == len(available_gpu_memory) # Get the kv cache tensor size - kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs, - available_gpu_memory) - num_gpu_blocks_set = set(config.num_blocks - for config in kv_cache_configs) - assert len(num_gpu_blocks_set) == 1, ( - f"num_gpu_blocks need to be the same across workers, " - f"but they are different: {num_gpu_blocks_set}") - num_gpu_blocks = num_gpu_blocks_set.pop() + kv_cache_configs = [ + get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, + available_gpu_memory_one_worker) + for kv_cache_spec_one_worker, available_gpu_memory_one_worker in + zip(kv_cache_specs, available_gpu_memory) + ] + + # Since we use a shared centralized controller, we need the + # `kv_cache_config` to be consistent across all workers to make sure + # all the memory operators can be applied to all workers. + unify_kv_cache_configs(kv_cache_configs) + + # All workers have the same kv_cache_config except layer names, so use + # an arbitrary one to get the number of blocks. 
+ assert all([ + cfg.num_blocks == kv_cache_configs[0].num_blocks + for cfg in kv_cache_configs + ]) + num_gpu_blocks = kv_cache_configs[0].num_blocks num_cpu_blocks = 0 # Initialize kv cache and warmup the execution diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index aa6ae83c26e..e3a4cd98c1f 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -62,14 +62,11 @@ def initialize_from_config(self, args=(kv_cache_configs, )) self.collective_rpc("compile_or_warm_up_model") - def determine_available_memory(self) -> int: # in bytes + def determine_available_memory(self) -> list[int]: # in bytes output = self.collective_rpc("determine_available_memory") - # Since we use a shared centralized controller, we take the minimum - # memory size across all workers to make sure all the memory - # operators can be applied to all workers. - return min(output) + return output - def get_kv_cache_specs(self) -> list[KVCacheSpec]: + def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: output = self.collective_rpc("get_kv_cache_spec") return output @@ -95,7 +92,7 @@ class UniProcExecutor(UniProcExecutorV0, Executor): class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): - def determine_available_memory(self) -> int: # in bytes + def determine_available_memory(self) -> list[int]: # in bytes # same as determine_num_available_blocks in v0, # we need to get the min across all ranks. memory = super().determine_available_memory() @@ -103,4 +100,4 @@ def determine_available_memory(self) -> int: # in bytes cpu_group = get_world_group().cpu_group memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) - return memory_tensor.item() + return [memory_tensor.item()] diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 1f885c10c8c..867b1b61c87 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -11,7 +11,7 @@ @dataclass -class KVCacheSpecBase: +class KVCacheSpec: """ A base class for specifying the KV cache format of one layer. """ @@ -55,7 +55,7 @@ def bytes_for_tokens(self, num_tokens: int) -> int: @dataclass -class FullAttentionSpec(KVCacheSpecBase): +class FullAttentionSpec(KVCacheSpec): num_kv_heads: int head_size: int dtype: torch.dtype @@ -76,9 +76,6 @@ def bytes_for_tokens(self, num_tokens: int) -> int: return cdiv(num_tokens, self.block_size) * self.page_size_bytes -KVCacheSpec = dict[str, KVCacheSpecBase] - - @dataclass class KVCacheTensor: """ @@ -89,6 +86,18 @@ class KVCacheTensor: size: int # The size of KV cache Tensor in bytes +@dataclass +class KVCacheGroupSpec: + """ + Represents a group of model layers that share the same KV cache block table. + These layers are regarded as one layer in the KV cache manager. + """ + # The names of model layers in this group + layer_names: list[str] + # The KV cache spec of this manager layer + kv_cache_spec: KVCacheSpec + + @dataclass class KVCacheConfig: """ @@ -99,17 +108,24 @@ class KVCacheConfig: """layer_name -> how to initialize KV cache for that layer""" tensors: dict[str, KVCacheTensor] """ - A list of kv-cache groups. Each group includes a set of layers with - the same kv-cache spec, and the total page_size of layers inside a group - is same across all groups (as the KVCacheManager only supports allocating - pages of the same size). For example: - 1. A model only uses full attention: one group with all layers in the model. - 2. 
(not implemented yet) A model with the same number of full attention - layers and sliding window attention layers: two groups, one for full - attention layers and one for sliding window attention layers. - 3. (not implemented yet) A model with 2 full attention layers and 4 sliding - window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). + The kv cache groups of the model. + The layers in the model are repeated in a pattern, e.g., a model + with 10 full attention layers and 20 sliding window attention layers can be + regarded as repeating the pattern (1 * full, 2 * sw) 10 times. + The KVCacheManager allocates different block tables for each of the 3 layers + in the pattern, and repeats each of them 10 times to generate the + block_table for the 30 layers in the model. + Therefore, we can group the layers in the model into 3 groups, each of which + contains 10 layers in the model. + The KVCacheManager allocates the block_table for each group based on its + kv_cache_spec, and the model runner applies the block table to each layer + in the group. + For example: + 1. A model only uses full attention. The pattern is + (num_hidden_layers * full), so there is only one group and the block table + is shared by all layers. + 2. (WIP) A model with 10 full attention layers and 20 sliding window + attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so + there are 3 groups, each of which represents 10 layers in the model. """ - groups: list[list[str]] - """the KVCacheSpec of the model""" - kv_cache_spec: KVCacheSpec + kv_cache_groups: list[KVCacheGroupSpec] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b186300a003..229849e4439 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1510,34 +1510,46 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.groups) > 1: + if len(kv_cache_config.kv_cache_groups) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") kv_caches: dict[str, torch.Tensor] = {} - for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % layer_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // layer_spec.page_size_bytes - if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, - layer_spec.head_size) - dtype = layer_spec.dtype - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - else: - raise NotImplementedError + for kv_cache_group in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group.kv_cache_spec + for layer_name in kv_cache_group.layer_names: + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes + # `num_blocks` is the number of blocks the model runner can use. + # `kv_cache_config.num_blocks` is the number of blocks that + # KVCacheManager may allocate. 
+ # Since different GPUs may have a different number of layers and + # different memory capacities, `num_blocks` can be different on + # different GPUs, and `kv_cache_config.num_blocks` is set to + # the min of all `num_blocks`. Verify it here. + assert num_blocks >= kv_cache_config.num_blocks + if isinstance(kv_cache_spec, FullAttentionSpec): + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + kv_caches[layer_name] = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + else: + # TODO: add new branches when introducing more types of + # KV cache specs. + raise ValueError("Unknown KV cache spec type.") bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, self.kv_caches) - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. @@ -1549,7 +1561,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla - kv_cache_spec: KVCacheSpec = {} + kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in forward_ctx.items(): if isinstance(attn_module, FusedMoE): continue diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index a63a2d02237..51b9f567396 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -185,7 +185,7 @@ def determine_available_memory(self) -> int: return int(available_kv_cache_memory) - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index ec3dcbc064c..fa6b29e94e3 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -304,7 +304,7 @@ def get_model(self) -> nn.Module: assert self.model is not None return self.model - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. @@ -315,7 +315,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size - kv_cache_spec: KVCacheSpec = {} + kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in forward_ctx.items(): # TODO: Support other attention modules, e.g., sliding window, # cross-attention, MLA. 
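Note: to make the new group-based interface concrete, here is a minimal sketch (not part of this diff) of how the per-layer spec dict now returned by `get_kv_cache_spec()` is folded into `KVCacheGroupSpec` objects via `create_kv_cache_group_specs()`. The layer names and spec parameters below are hypothetical.

```python
# Minimal sketch (not part of this diff); layer names are made up.
import torch

from vllm.v1.core.kv_cache_utils import create_kv_cache_group_specs
from vllm.v1.kv_cache_interface import FullAttentionSpec

spec = FullAttentionSpec(block_size=16,
                         num_kv_heads=8,
                         head_size=128,
                         dtype=torch.float16,
                         use_mla=False)
kv_cache_spec = {"model.layers.0.attn": spec, "model.layers.1.attn": spec}

# All layers share one spec, so they form a single group -- the only layout
# the uniform-type path in this diff produces.
groups = create_kv_cache_group_specs(kv_cache_spec,
                                     [list(kv_cache_spec.keys())])
assert len(groups) == 1
assert groups[0].layer_names == ["model.layers.0.attn", "model.layers.1.attn"]
assert groups[0].kv_cache_spec == spec
```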
@@ -818,31 +818,33 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - if len(kv_cache_config.groups) > 1: + if len(kv_cache_config.kv_cache_groups) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") kv_caches: dict[str, torch.Tensor] = {} - for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % layer_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // layer_spec.page_size_bytes - if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, - layer_spec.head_size) - dtype = layer_spec.dtype - - tpu_k_cache = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - tpu_v_cache = torch.zeros_like(tpu_k_cache) - - kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) - else: - raise NotImplementedError + for kv_cache_group in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group.kv_cache_spec + for layer_name in kv_cache_group.layer_names: + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes + if isinstance(kv_cache_spec, FullAttentionSpec): + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + + tpu_k_cache = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + tpu_v_cache = torch.zeros_like(tpu_k_cache) + + kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache) + else: + raise NotImplementedError bind_kv_cache( kv_caches, diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index dbb231950d0..d56c25dd9da 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -189,7 +189,7 @@ def compile_or_warm_up_model(self) -> None: def get_model(self) -> nn.Module: return self.model_runner.get_model() - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 51d2da2344b..487a49b6211 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -51,7 +51,7 @@ def __init__( self.device: Optional[torch.device] = None self.model_runner: Optional[nn.Module] = None - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """Get specifications for KV cache implementation.""" raise NotImplementedError
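Note: the per-worker flow that `EngineCore._initialize_kv_caches()` now follows can be exercised in isolation with a minimal sketch like the one below (not part of this diff): build one `KVCacheConfig` per worker, then call `unify_kv_cache_configs()` so the shared centralized controller sees a single consistent `num_blocks`. Layer names and sizes are illustrative only.

```python
# Minimal sketch (not part of this diff) of unifying per-worker configs.
import torch

from vllm.v1.core.kv_cache_utils import unify_kv_cache_configs
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheTensor)

spec = FullAttentionSpec(block_size=16,
                         num_kv_heads=8,
                         head_size=128,
                         dtype=torch.float16,
                         use_mla=False)


def config_for_worker(num_blocks: int) -> KVCacheConfig:
    # One full-attention layer per worker; the tensor is sized so the worker
    # could use `num_blocks` blocks on its own.
    return KVCacheConfig(
        num_blocks=num_blocks,
        tensors={
            "layer0": KVCacheTensor(size=num_blocks * spec.page_size_bytes)
        },
        kv_cache_groups=[KVCacheGroupSpec(["layer0"], spec)],
    )


# Workers with different free memory end up with different block counts...
configs = [config_for_worker(1000), config_for_worker(800)]
unify_kv_cache_configs(configs)
# ...but after unification every worker reports the smallest count, while
# each worker's tensors keep their original (possibly larger) size.
assert all(cfg.num_blocks == 800 for cfg in configs)
```

Shrinking only `num_blocks` rather than the tensors is safe because, as the new comment in `unify_kv_cache_configs` notes, it is valid to use just the first `num_blocks` blocks of an over-sized tensor.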