Commit 9e07c36

[V1] Support cross-layer KV sharing
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent a3896c7 commit 9e07c36

31 files changed: +581 -72 lines changed
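
The tests added below exercise a new kv_sharing_target_layer_name argument on Attention. As a rough sketch distilled from those tests (not part of this commit's diff), a layer opts into cross-layer KV sharing by naming the earlier layer whose KV cache it should reuse:

from vllm.attention.layer import Attention

# Layer 0 allocates and writes its own KV cache.
attn_0 = Attention(num_heads=8, head_size=64, scale=1.0,
                   prefix="model.layers.0.self_attn.attn")

# Layer 1 allocates no KV cache of its own; it attends over the KV cache
# written by layer 0 instead.
attn_1 = Attention(
    num_heads=8,
    head_size=64,
    scale=1.0,
    prefix="model.layers.1.self_attn.attn",
    kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
)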

tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 166 additions & 0 deletions
@@ -3,8 +3,12 @@
 
 import pytest
 
+from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.sampling_params import SamplingParams
+from vllm.utils import GiB_bytes
+from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
+                                          get_kv_cache_config)
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                         SchedulerOutput)
 from vllm.v1.worker.tpu_model_runner import (
@@ -341,3 +345,165 @@ def test_get_req_paddings():
     assert _get_req_paddings(1, 32) == [8, 16, 32]
     assert _get_req_paddings(8, 32) == [8, 16, 32]
     assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
+
+
+def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    invalid_layer = "model.layers.0.cross_attn.attn"
+    fwd_context = {
+        layer_0:
+        Attention(
+            num_heads=8,
+            head_size=64,
+            scale=1.0,
+            prefix=layer_0,
+        ),
+        layer_1:
+        Attention(
+            num_heads=8,
+            head_size=64,
+            scale=1.0,
+            prefix=layer_1,
+            # invalid layer: cross_attn.attn doesn't exist!
+            kv_sharing_target_layer_name=invalid_layer,
+        )
+    }
+    model_runner.vllm_config.compilation_config.static_forward_context = (
+        fwd_context)
+    error_msg = f"{invalid_layer} is not a Attention layer in the model"
+    with pytest.raises(ValueError, match=error_msg):
+        model_runner.get_kv_cache_spec()
+
+
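Aside (not part of the diff): the expected ValueError above implies that get_kv_cache_spec() validates each sharing target against the layers registered in the forward context. A hypothetical sketch of that kind of check, with an illustrative helper name:

from vllm.attention.layer import Attention

def _check_kv_sharing_target(target_name: str, fwd_context: dict) -> None:
    # The target must itself be a registered Attention layer; otherwise
    # there is no KV cache for the referring layer to reuse.
    if target_name not in fwd_context or not isinstance(
            fwd_context[target_name], Attention):
        raise ValueError(
            f"{target_name} is not a Attention layer in the model")
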
+def test_init_kv_cache_without_kv_sharing(model_runner):
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    fwd_context = {
+        layer_0: Attention(
+            num_heads=8,
+            head_size=64,
+            scale=1.0,
+        ),
+        layer_1: Attention(
+            num_heads=8,
+            head_size=64,
+            scale=1.0,
+        )
+    }
+    vllm_config = model_runner.vllm_config
+    # Set high context length to test max context length estimation
+    model_runner.vllm_config.model_config.max_model_len = 3_000_000
+    # Hacky way to initialize forward context without initializing the model.
+    model_runner.vllm_config.compilation_config.static_forward_context = (
+        fwd_context)
+    vllm_ctx = (
+        model_runner.vllm_config.compilation_config.static_forward_context)
+    kv_cache_spec = model_runner.get_kv_cache_spec()
+    assert len(kv_cache_spec) == 2
+    assert len(model_runner.shared_kv_cache_layers) == 0
+
+    available_memory = 20 * GiB_bytes
+    # page size for layer 0's kv_cache_spec is 32KB
+    num_expected_blocks = 327680  # 20GB / 32KB / 2 (num layers)
+    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                          available_memory)
+    assert kv_cache_config.num_blocks == num_expected_blocks
+    assert len(kv_cache_config.tensors) == 2
+    assert kv_cache_config.tensors[layer_0].size == available_memory // 2
+    assert kv_cache_config.tensors[layer_1].size == available_memory // 2
+
+    max_context_len = \
+        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
+    # max context len without KV sharing; with sharing it should be 2x this
+    assert max_context_len == 1310720
+
+    # important: override tensor size to prevent large mem alloc during test
+    # this will only allocate 2 blocks worth of memory (2 * 32kb)
+    kv_cache_config.num_blocks = 1
+    for layer in kv_cache_config.tensors:
+        kv_cache_config.tensors[layer].size = \
+            kv_cache_spec[layer].page_size_bytes
+
+    model_runner.initialize_kv_cache(kv_cache_config)
+
+    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
+    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    # check layer 1 kv cache does NOT share memory with layer 0
+    assert id(layer_1_kv) != id(layer_0_kv)
+
+    # check layer 1 added to kv cache group's layer names
+    assert len(kv_cache_config.kv_cache_groups) == 1
+    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
+    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
+    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
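Aside (not part of the diff): the numbers asserted in test_init_kv_cache_without_kv_sharing can be sanity-checked by hand. This assumes the 32 KiB per-block page size stated in the test comments and a 16-token block size (so roughly 2 KiB of KV per token per layer); both constants are assumptions, not read from the diff:

GiB = 1 << 30
page_size_bytes = 32 * 1024   # per block, per layer (from the test comment)
num_layers = 2                # both layers store their own KV

# 20 GiB of KV memory is split evenly across the two layers' tensors.
assert 20 * GiB // page_size_bytes // num_layers == 327680

# With assumed 16-token blocks, each token costs 2 KiB per layer, 4 KiB across
# both layers, so 5 GiB of cache supports at most 1310720 tokens of context.
bytes_per_token = num_layers * (page_size_bytes // 16)
assert 5 * GiB // bytes_per_token == 1310720
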
+def test_init_kv_cache_with_kv_sharing_valid(model_runner):
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    fwd_context = {
+        layer_0:
+        Attention(
+            num_heads=8,
+            head_size=64,
+            scale=1.0,
+            prefix=layer_0,
+        ),
+        layer_1:
+        Attention(
+            num_heads=8,
+            head_size=64,
+            scale=1.0,
+            prefix=layer_1,
+            kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
+        )
+    }
+    vllm_config = model_runner.vllm_config
+    # Set high context length to test max context length estimation
+    model_runner.vllm_config.model_config.max_model_len = 3_000_000
+    model_runner.vllm_config.compilation_config.static_forward_context = (
+        fwd_context)
+    vllm_ctx = (
+        model_runner.vllm_config.compilation_config.static_forward_context)
+    kv_cache_spec = model_runner.get_kv_cache_spec()
+    assert len(kv_cache_spec) == 1
+    assert layer_0 in kv_cache_spec
+    assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
+
+    available_memory = 20 * GiB_bytes
+    # page size for layer 0's kv_cache_spec is 32KB
+    # with KV sharing, we can allocate (available_mem // page_size // 1) blocks
+    # which is twice as many as without KV sharing
+    num_expected_blocks = 655360  # 20GB / 32KB
+    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                          available_memory)
+    assert kv_cache_config.num_blocks == num_expected_blocks
+    assert len(kv_cache_config.tensors) == 1
+    # Each layer now has twice the available memory for KV cache
+    # compared to no KV sharing
+    assert kv_cache_config.tensors[layer_0].size == available_memory
+
+    max_context_len = \
+        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
+    # max context len with KV sharing should be 2x as large as without
+    assert max_context_len == 2 * 1310720
+
+    # important: override tensor size to prevent large mem alloc during test
+    # this will only allocate 1 block worth of memory (32kb)
+    kv_cache_config.num_blocks = 1
+    kv_cache_config.tensors[layer_0].size = \
+        kv_cache_spec[layer_0].page_size_bytes
+
+    model_runner.initialize_kv_cache(kv_cache_config)
+
+    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
+    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    # check layer 1 kv cache shares memory with layer 0
+    assert id(layer_1_kv) == id(layer_0_kv)
+
+    # check layer 1 added to kv cache group's layer names
+    assert len(kv_cache_config.kv_cache_groups) == 1
+    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
+    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
+    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
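
Aside (not part of the diff): the same back-of-the-envelope check for the KV sharing case, under the same assumed constants (32 KiB page, 16-token blocks). With layer 1 reusing layer 0's cache, all 20 GiB backs a single tensor and the per-token KV cost halves, which is the 2x relationship the test asserts:

GiB = 1 << 30
page_size_bytes = 32 * 1024

# Only layer 0 stores KV, so every 32 KiB page becomes a usable block.
assert 20 * GiB // page_size_bytes == 655360

# Per-token cost drops from 4 KiB to 2 KiB, doubling the max context estimate.
bytes_per_token = page_size_bytes // 16
assert 5 * GiB // bytes_per_token == 2 * 1310720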
