Commit f214340

[V1] Support cross-layer KV sharing
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent fadb8d5 commit f214340

10 files changed: +458, -51 lines changed
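This commit lets an attention layer reuse the KV cache allocated for an earlier layer instead of allocating its own, so models that share KV across layers need proportionally less KV cache memory for the same context length. As a hedged illustration (not code from this commit), the sketch below reproduces the layer-to-target mapping used by the updated test_estimate_max_model_len test further down, where every kv_sharing_factor-th layer points at the layer kv_sharing_factor - 1 positions before it:

# Illustrative sketch only: derive per-layer KV sharing targets the way the
# updated test_estimate_max_model_len test does.
from typing import Optional


def kv_sharing_targets(num_layers: int,
                       kv_sharing_factor: int) -> dict[str, Optional[int]]:
    """Map each layer name to the layer index whose KV cache it reuses."""
    targets: dict[str, Optional[int]] = {}
    for i in range(num_layers):
        target: Optional[int] = None
        if kv_sharing_factor > 0 and (i + 1) % kv_sharing_factor == 0:
            # with factor 2: layer 1 reuses layer 0, layer 3 reuses layer 2, ...
            target = i - (kv_sharing_factor - 1)
        targets[f"layer_{i}"] = target
    return targets


assert kv_sharing_targets(4, 2) == {
    "layer_0": None,  # allocates its own KV cache
    "layer_1": 0,     # reuses layer 0's KV cache
    "layer_2": None,
    "layer_3": 2,
}

With a sharing factor of 2, half of the layers allocate nothing, which is why the updated parametrization below expects the same estimated max model length from 4 GB of KV cache memory instead of 8 GB.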

tests/v1/core/test_kv_cache_utils.py

Lines changed: 85 additions & 10 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import copy
 import importlib
 
 import pytest
@@ -15,12 +16,13 @@
                                           PrefixCachingMetrics,
                                           estimate_max_model_len,
                                           generate_block_hash_extra_keys,
+                                          get_kv_cache_config,
                                           hash_block_tokens,
                                           hash_request_tokens,
                                           unify_kv_cache_configs)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec, KVCacheTensor,
-                                        SlidingWindowSpec)
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        KVCacheTensor, SlidingWindowSpec)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request
 
@@ -557,13 +559,76 @@ def test_merge_kv_cache_spec():
     assert merged_layer_spec.sliding_window == 1
 
 
+def test_get_kv_cache_config_cross_layer_kv_sharing():
+    # Create a VllmConfig
+    model_id = "Qwen/Qwen1.5-7B"
+    max_model_len = 16383
+    model_config = ModelConfig(
+        model_id,
+        task="generate",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        max_model_len=max_model_len,
+    )
+    scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        scheduler_config=scheduler_config,
+    )
+
+    # Create KV cache specs
+
+    # max memory usage bytes calculated as:
+    # 1024 * 2 * 16 * 32 * 128 * 2
+    spec = FullAttentionSpec(
+        block_size=16,
+        num_kv_heads=32,
+        head_size=128,
+        dtype=torch.float16,
+        use_mla=False,
+    )
+    assert spec.max_memory_usage_bytes(vllm_config) == 268435456
+    assert spec.page_size_bytes == 262144
+
+    # layer_1 shares KV cache with layer_0
+    spec_shared_0 = copy.copy(spec)
+    spec_shared_0.kv_sharing_target_layer_idx = 0
+    assert spec_shared_0.max_memory_usage_bytes(vllm_config) == 0
+
+    kv_cache_spec: dict[str, KVCacheSpec] = {
+        "layer_0": spec,
+        "layer_1": spec_shared_0,
+    }
+
+    available_memory = 268435456
+
+    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                          available_memory)
+    assert kv_cache_config.num_blocks == 1024
+    assert kv_cache_config.tensors["layer_0"].size == available_memory
+    assert kv_cache_config.tensors["layer_1"].size == 0
+    assert len(kv_cache_config.kv_cache_groups) == 1
+    kv_sharing_layer_mapping = (
+        kv_cache_config.kv_cache_groups[0].kv_sharing_layer_mapping)
+    assert kv_sharing_layer_mapping is not None
+    assert len(kv_sharing_layer_mapping) == 1
+    assert kv_sharing_layer_mapping['layer_1'] == 0
+
+
 @pytest.mark.parametrize(
-    ("model_id", "max_model_len", "want_estimated_max_len"), [
-        ("Qwen/Qwen1.5-7B", 16385, 16384),
-        ("Qwen/Qwen1.5-7B", 16383, 16383),
-    ])
+    ("model_id", "max_model_len", "want_estimated_max_len",
+     "kv_sharing_factor", "available_mem_gb"), [
+        ("Qwen/Qwen1.5-7B", 16385, 16384, 0, 8),
+        ("Qwen/Qwen1.5-7B", 16383, 16383, 0, 8),
+        ("Qwen/Qwen1.5-7B", 16383, 16383, 2, 4),
+    ])
 def test_estimate_max_model_len(model_id, max_model_len,
-                                want_estimated_max_len):
+                                want_estimated_max_len, kv_sharing_factor,
+                                available_mem_gb):
     # Create a VllmConfig
     model_config = ModelConfig(
         model_id,
@@ -585,17 +650,27 @@ def test_estimate_max_model_len(model_id, max_model_len,
     # Create KV cache specs
     kv_cache_spec = {}
     for i in range(32):
+        kv_sharing_target_layer_idx = None
+        if kv_sharing_factor > 0:
+            share_kv = (i + 1) % kv_sharing_factor == 0
+            if share_kv:
+                # layer idx 1 will use KV cache from idx 0, etc
+                kv_sharing_target_layer_idx = i - (kv_sharing_factor - 1)
+
         layer_name = f"layer_{i}"
-        kv_cache_spec[layer_name] = FullAttentionSpec(
+        spec = FullAttentionSpec(
            block_size=16,
            num_kv_heads=32,
            head_size=128,
            dtype=torch.float16,
            use_mla=False,
        )
-    # Estimate the maximum model length, 16384 model_len need 8GB
+        spec.kv_sharing_target_layer_idx = kv_sharing_target_layer_idx
+        kv_cache_spec[layer_name] = spec
+    # Estimate the maximum model length, 16384 model_len need 8GB normally
+    # with cross-layer KV sharing with sharing factor=2, we only need 4GB
     estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
-                                               8 * GiB_bytes)
+                                               available_mem_gb * GiB_bytes)
     assert estimated_max_len == want_estimated_max_len
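The constants asserted in test_get_kv_cache_config_cross_layer_kv_sharing above follow directly from the spec: a max_model_len of 16383 with block_size 16 rounds up to 1024 blocks, and each block stores K and V for 16 tokens x 32 KV heads x 128 head dims in float16 (2 bytes). A quick arithmetic check, given purely as an illustration and not as code from the commit:

# Worked check of the numbers asserted in the test above (illustrative only).
num_blocks = -(-16383 // 16)       # ceil(16383 / 16) = 1024 blocks
page_size = 2 * 16 * 32 * 128 * 2  # K+V * block_size * kv_heads * head_size * fp16 bytes
assert page_size == 262144              # matches spec.page_size_bytes
assert num_blocks * page_size == 268435456  # 256 MiB for layer_0
# layer_1 reuses layer_0's cache, so it needs 0 bytes; 268435456 bytes of
# available memory therefore still yields 1024 blocks covering both layers.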

tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 78 additions & 0 deletions
@@ -3,10 +3,13 @@
 
 import pytest
 
+from vllm.attention import Attention
 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.worker.tpu_model_runner import (
     TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
     _get_padded_token_len, _get_req_paddings, _get_token_paddings)
@@ -292,6 +295,81 @@ def test_update_states_request_unscheduled(model_runner):
     assert not _is_req_scheduled(model_runner, req_ids[1])
 
 
+def test_init_kv_cache_shared_valid(model_runner):
+    spec = FullAttentionSpec(
+        block_size=16,
+        num_kv_heads=model_runner.model_config.get_num_kv_heads(
+            model_runner.parallel_config),
+        head_size=model_runner.model_config.get_head_size(),
+        dtype=model_runner.kv_cache_dtype,
+        use_mla=False,
+    )
+    kv_cache_config = KVCacheConfig(
+        num_blocks=10,
+        tensors={
+            "layer.0": KVCacheTensor(size=spec.page_size_bytes * 12),
+            "layer.1": KVCacheTensor(size=0),
+        },
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                # intentionally switch order to check layer names are sorted
+                # such that layers that reuse KV cache from earlier layers
+                # are processed after all layers that allocate KV cache
+                layer_names=["layer.1", "layer.0"],
+                kv_cache_spec=spec,
+                kv_sharing_layer_mapping={"layer.1": 0}),
+        ])
+
+    fwd_context = (
+        model_runner.vllm_config.compilation_config.static_forward_context)
+    # populate forward context before init kv
+    fwd_context['layer.0'] = Attention(32, 128, 0.1)
+    fwd_context['layer.1'] = Attention(32, 128, 0.1)
+
+    model_runner.initialize_kv_cache(kv_cache_config)
+
+    # check memory references of KV caches for layer 0 and 1 are the same
+    assert id(model_runner.kv_caches[0]) == id(model_runner.kv_caches[1])
+    assert len(fwd_context["layer.0"].kv_cache) == 1
+    assert len(fwd_context["layer.1"].kv_cache) == 1
+    layer_1_kv_cache = fwd_context["layer.1"].kv_cache[0]
+    layer_2_kv_cache = fwd_context["layer.1"].kv_cache[0]
+    assert id(layer_1_kv_cache) == id(layer_2_kv_cache)
+
+
+@pytest.mark.parametrize("target_layer_idx", [1, 2])
+def test_init_kv_cache_shared_invalid(model_runner, target_layer_idx):
+    spec = FullAttentionSpec(
+        block_size=16,
+        num_kv_heads=model_runner.model_config.get_num_kv_heads(
+            model_runner.parallel_config),
+        head_size=model_runner.model_config.get_head_size(),
+        dtype=model_runner.kv_cache_dtype,
+        use_mla=False,
+    )
+    kv_cache_config = KVCacheConfig(
+        num_blocks=10,
+        tensors={
+            "layer.1": KVCacheTensor(size=spec.page_size_bytes * 12),
+            "layer.0": KVCacheTensor(size=0),
+        },
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                layer_names=["layer.1", "layer.0"],
+                kv_cache_spec=spec,
+                kv_sharing_layer_mapping={"layer.0": target_layer_idx}),
+        ])
+
+    if target_layer_idx >= 2:
+        error_msg = "2 is an invalid layer index!"
+    else:
+        error_msg = ("layer.0 cannot share KV cache with layer.1 which comes"
+                     " after it")
+
+    with pytest.raises(AssertionError, match=error_msg):
+        model_runner.initialize_kv_cache(kv_cache_config)
+
+
 def test_get_paddings():
     # Bucketed padding
     min_token_size, max_token_size, padding_gap = 16, 512, 64
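The error messages matched in test_init_kv_cache_shared_invalid outline the validation that initialize_kv_cache is expected to perform on kv_sharing_layer_mapping. The following is a minimal sketch of equivalent checks, written as an assumption about the intended behaviour rather than code taken from this commit:

# Hedged sketch of the checks the tests above exercise; the model runner's
# actual implementation in this commit may differ in detail.
def check_kv_sharing(layer_names: list[str],
                     kv_sharing_layer_mapping: dict[str, int]) -> None:
    # Process layers in sorted order so a layer can only point at a layer
    # whose KV cache has already been allocated.
    ordered = sorted(layer_names)
    for layer_name, target_idx in kv_sharing_layer_mapping.items():
        assert 0 <= target_idx < len(ordered), (
            f"{target_idx} is an invalid layer index!")
        assert ordered.index(layer_name) > target_idx, (
            f"{layer_name} cannot share KV cache with {ordered[target_idx]}"
            " which comes after it")

For the valid test case above, sorting ["layer.1", "layer.0"] puts layer.0 first, so mapping layer.1 to index 0 passes; mapping layer.0 to index 1 or 2 trips one of the two assertions, matching the expected messages.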

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 90 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 import pytest
 
+from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig)
 from vllm.sampling_params import SamplingParams
@@ -48,8 +49,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
     runner.initialize_attn_backend(kv_cache_config)
 
 
-@pytest.fixture
-def model_runner():
+def init_model_runner():
     scheduler_config = SchedulerConfig(
         max_num_seqs=10,
         max_num_batched_tokens=512,
@@ -80,6 +80,17 @@ def model_runner():
 
     device = "cuda"
     runner = GPUModelRunner(vllm_config, device)
+    return runner
+
+
+@pytest.fixture(autouse=True)
+def model_runner(request):
+    runner = init_model_runner()
+
+    if 'skipkvinit' in request.keywords:
+        # do not init kv cache for specific tests
+        return runner
+
     initialize_kv_cache(runner)
     return runner
 
@@ -321,3 +332,80 @@ def test_update_states_request_unscheduled(model_runner):
 
     assert _is_req_added(model_runner, req_ids[1])
     assert not _is_req_scheduled(model_runner, req_ids[1])
+
+
+@pytest.mark.skipkvinit
+def test_init_kv_cache_shared_valid(model_runner):
+    spec = FullAttentionSpec(
+        block_size=16,
+        num_kv_heads=model_runner.model_config.get_num_kv_heads(
+            model_runner.parallel_config),
+        head_size=model_runner.model_config.get_head_size(),
+        dtype=model_runner.kv_cache_dtype,
+        use_mla=False,
+    )
+    kv_cache_config = KVCacheConfig(
+        num_blocks=10,
+        tensors={
+            "layer.0": KVCacheTensor(size=spec.page_size_bytes * 12),
+            "layer.1": KVCacheTensor(size=0),
+        },
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                # intentionally switch order to check layer names are sorted
+                # such that layers that reuse KV cache from earlier layers
+                # are processed after all layers that allocate KV cache
+                layer_names=["layer.1", "layer.0"],
+                kv_cache_spec=spec,
+                kv_sharing_layer_mapping={"layer.1": 0}),
+        ])
+
+    fwd_context = (
+        model_runner.vllm_config.compilation_config.static_forward_context)
+    # populate forward context before init kv
+    fwd_context['layer.0'] = Attention(32, 128, 0.1)
+    fwd_context['layer.1'] = Attention(32, 128, 0.1)
+
+    model_runner.initialize_kv_cache(kv_cache_config)
+
+    # check memory references of KV caches for layer 0 and 1 are the same
+    assert id(model_runner.kv_caches[0]) == id(model_runner.kv_caches[1])
+    assert len(fwd_context["layer.0"].kv_cache) == 1
+    assert len(fwd_context["layer.1"].kv_cache) == 1
+    layer_1_kv_cache = fwd_context["layer.1"].kv_cache[0]
+    layer_2_kv_cache = fwd_context["layer.1"].kv_cache[0]
+    assert id(layer_1_kv_cache) == id(layer_2_kv_cache)
+
+
+@pytest.mark.skipkvinit
+@pytest.mark.parametrize("target_layer_idx", [1, 2])
+def test_init_kv_cache_shared_invalid(model_runner, target_layer_idx):
+    spec = FullAttentionSpec(
+        block_size=16,
+        num_kv_heads=model_runner.model_config.get_num_kv_heads(
+            model_runner.parallel_config),
+        head_size=model_runner.model_config.get_head_size(),
+        dtype=model_runner.kv_cache_dtype,
+        use_mla=False,
+    )
+    kv_cache_config = KVCacheConfig(
+        num_blocks=10,
+        tensors={
+            "layer.1": KVCacheTensor(size=spec.page_size_bytes * 12),
+            "layer.0": KVCacheTensor(size=0),
+        },
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                layer_names=["layer.1", "layer.0"],
+                kv_cache_spec=spec,
+                kv_sharing_layer_mapping={"layer.0": target_layer_idx}),
+        ])
+
+    if target_layer_idx >= 2:
+        error_msg = "2 is an invalid layer index!"
+    else:
+        error_msg = ("layer.0 cannot share KV cache with layer.1 which comes"
+                     " after it")
+
+    with pytest.raises(AssertionError, match=error_msg):
+        model_runner.initialize_kv_cache(kv_cache_config)
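A side note on the skipkvinit mark used by the new autouse fixture: pytest warns about unregistered custom marks (and rejects them under --strict-markers), so the suite presumably registers it somewhere outside this excerpt. One conventional, hypothetical way to do so in a conftest.py:

# Hypothetical conftest.py snippet; the commit may register the mark elsewhere
# (e.g. in pyproject.toml or pytest.ini).
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "skipkvinit: skip KV cache initialization in the model_runner fixture")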

vllm/attention/layer.py

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,7 @@ def __init__(
         use_mla: bool = False,
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_idx: Optional[int] = None,
         **extra_impl_args,
     ) -> None:
         """
@@ -102,6 +103,10 @@ def __init__(
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
         self.sliding_window = sliding_window
+        self.kv_sharing_target_layer_idx = kv_sharing_target_layer_idx
+        if kv_sharing_target_layer_idx is not None:
+            extra_impl_args['kv_sharing_target_layer_idx'] = (
+                kv_sharing_target_layer_idx)
 
         quant_method = quant_config.get_quant_method(
             self, prefix=prefix) if quant_config else None
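The new kv_sharing_target_layer_idx argument is stored on the Attention layer and forwarded to the backend implementation through extra_impl_args. A hedged usage sketch, reusing the constructor values from the tests above and assuming it runs where a vLLM config context is active (as the model runner provides in those tests):

# Hedged sketch, not code from this commit: the second layer declares that it
# reuses the KV cache of layer index 0. The values (32 heads, head_size 128,
# scale 0.1) mirror the tests above; real model code would also pass
# cache_config/quant_config and a layer-specific prefix.
from vllm.attention import Attention

attn_0 = Attention(32, 128, 0.1, prefix="layers.0.attn")
attn_1 = Attention(32, 128, 0.1, prefix="layers.1.attn",
                   kv_sharing_target_layer_idx=0)

# Per the diff above, the index is kept on the layer and also injected into
# extra_impl_args so the attention backend implementation can honor it.
assert attn_1.kv_sharing_target_layer_idx == 0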
