[v1] Redo "Support multiple KV cache groups in GPU model runner (#17945)" #18593

Merged: 3 commits, May 23, 2025

tests/v1/core/test_kv_cache_utils.py: 68 additions, 3 deletions
@@ -19,7 +19,8 @@
hash_request_tokens,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
- KVCacheGroupSpec, KVCacheTensor)
+ KVCacheGroupSpec, KVCacheTensor,
+ SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request

@@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
- use_mla=False):
+ use_mla=False,
+ sliding_window=None):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
- use_mla=use_mla)
+ use_mla=use_mla,
+ sliding_window=sliding_window)


def test_none_hash(monkeypatch):
@@ -492,6 +495,68 @@ def test_unify_kv_cache_configs():
unify_kv_cache_configs(diff_kv_cache_config)


def test_merge_kv_cache_spec():
same_layer_specs = [
new_kv_cache_spec(num_kv_heads=32),
new_kv_cache_spec(num_kv_heads=32),
]
merged_layer_spec = same_layer_specs[0].merge(same_layer_specs)
assert merged_layer_spec.block_size == 16
assert merged_layer_spec.num_kv_heads == 32
assert merged_layer_spec.head_size == 64
assert merged_layer_spec.dtype == torch.float32
assert merged_layer_spec.sliding_window is None

different_layer_specs = [
new_kv_cache_spec(num_kv_heads=32),
new_kv_cache_spec(num_kv_heads=16),
]
with pytest.raises(AssertionError):
different_layer_specs[0].merge(different_layer_specs)

full_spec = new_kv_cache_spec(num_kv_heads=32)
different_type_layer_specs = [
full_spec,
SlidingWindowSpec(
block_size=full_spec.block_size,
num_kv_heads=full_spec.num_kv_heads,
head_size=full_spec.head_size,
dtype=full_spec.dtype,
use_mla=full_spec.use_mla,
sliding_window=1,
),
]
with pytest.raises(AssertionError):
different_type_layer_specs[0].merge(different_type_layer_specs)
with pytest.raises(AssertionError):
different_type_layer_specs[1].merge(different_type_layer_specs)

different_sliding_window_layer_specs = [
new_kv_cache_spec(num_kv_heads=32),
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
new_kv_cache_spec(num_kv_heads=32, sliding_window=2),
]
with pytest.raises(ValueError):
different_sliding_window_layer_specs[0].merge(
different_sliding_window_layer_specs)

same_sliding_window_layer_specs = [
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
]
merged_layer_spec = same_sliding_window_layer_specs[0].merge(
same_sliding_window_layer_specs)
assert merged_layer_spec.sliding_window == 1

same_sliding_window_layer_spec_with_none = [
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
new_kv_cache_spec(num_kv_heads=32, sliding_window=None),
]
merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge(
same_sliding_window_layer_spec_with_none)
assert merged_layer_spec.sliding_window == 1


@pytest.mark.parametrize(
("model_id", "max_model_len", "want_estimated_max_len"), [
("Qwen/Qwen1.5-7B", 16385, 16384),
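
The new test_merge_kv_cache_spec above pins down the merge rules for per-layer KV cache specs: layers can only be merged when their type and shape-related fields match, concrete sliding-window values must agree, and a None sliding window defers to the concrete value. The snippet below is a minimal stand-alone sketch of those rules only, not vLLM's implementation; SimpleSpec and merge_specs are hypothetical names.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class SimpleSpec:
    """Stand-in for a per-layer KV cache spec (hypothetical, not vLLM's class)."""
    block_size: int
    num_kv_heads: int
    head_size: int
    sliding_window: Optional[int] = None


def merge_specs(specs: list[SimpleSpec]) -> SimpleSpec:
    """Merge the specs of one KV cache group, mirroring the rules in the test."""
    first = specs[0]
    for spec in specs:
        # Same type and same shape-affecting fields are required.
        assert type(spec) is type(first)
        assert (spec.block_size, spec.num_kv_heads,
                spec.head_size) == (first.block_size, first.num_kv_heads,
                                    first.head_size)
    # Concrete sliding windows must agree; None simply defers to them.
    windows = {s.sliding_window for s in specs if s.sliding_window is not None}
    if len(windows) > 1:
        raise ValueError(f"conflicting sliding windows: {windows}")
    return replace(first, sliding_window=windows.pop() if windows else None)


# Mirrors the last case of the test: sliding_window=1 merged with None gives 1.
merged = merge_specs([
    SimpleSpec(16, 32, 64, sliding_window=1),
    SimpleSpec(16, 32, 64, sliding_window=None),
])
assert merged.sliding_window == 1
```
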
tests/v1/core/test_prefix_caching.py: 18 additions, 18 deletions
@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
- assert blocks.get_block_ids() == [1, 2, 3, 4]
+ assert blocks.get_block_ids() == [[1, 2, 3, 4]]

# Check full block metadata
parent_block_hash = None
@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
- assert computed_blocks.get_block_ids() == [1, 2, 3]
+ assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
- assert blocks.get_block_ids() == [5]
+ assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2

@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
- assert computed_blocks.get_block_ids() == [1, 2, 3]
+ assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
- assert blocks.get_block_ids() == [6]
+ assert blocks.get_block_ids() == [[6]]

# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
len(computed_blocks.blocks) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
- assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
+ assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -208,7 +208,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
- assert blocks.get_block_ids() == [1, 2, 3, 4]
+ assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0_block_hashes = [b.block_hash for b in blocks.blocks]

# Check full block metadata
@@ -233,13 +233,13 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
- assert computed_blocks.get_block_ids() == [1, 2, 3]
+ assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
- assert blocks.get_block_ids() == [5]
+ assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2

@@ -277,11 +277,11 @@ def test_prefill_plp():
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
- assert block_ids != [1, 2, 3, 4]
+ assert block_ids != [[1, 2, 3, 4]]

# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
- for block_id in block_ids:
+ for block_id in block_ids[0]:
assert manager.block_pool.blocks[block_id].ref_cnt == 1

manager.free(req2)
@@ -307,7 +307,7 @@ def test_decode():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
- assert blocks.get_block_ids() == [1, 2, 3, 4]
+ assert blocks.get_block_ids() == [[1, 2, 3, 4]]

# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@@ -379,12 +379,12 @@ def test_evict():
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
- assert computed_blocks.get_block_ids() == [1, 2]
+ assert computed_blocks.get_block_ids() == [[1, 2]]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks) * 16,
computed_blocks)
- assert blocks.get_block_ids() == [10]
+ assert blocks.get_block_ids() == [[10]]
assert manager.block_pool.free_block_queue.num_free_blocks == 7


@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
- assert blocks.get_block_ids() == [1, 2, 3, 4]
+ assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
@@ -686,7 +686,7 @@ def test_cache_key_salting():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
- assert blocks.get_block_ids() == [1, 2, 3, 4]
+ assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
- assert blocks.get_block_ids() == [1, 2, 3, 4]
+ assert blocks.get_block_ids() == [[1, 2, 3, 4]]

unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
Expand All @@ -808,7 +808,7 @@ def test_reset_prefix_cache():
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
- assert blocks.get_block_ids() == [5]
+ assert blocks.get_block_ids() == [[5]]

# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()
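
Every change in this file reflects the same interface update: with multiple KV cache groups, get_block_ids() returns one list of block IDs per group instead of a single flat list, so single-group expectations gain an outer list and per-block iteration goes through block_ids[0]. A small sketch of the shape change, using hypothetical block ID values rather than the fixture values above:

```python
# Old vs. new return shape of get_block_ids(), with made-up IDs.
flat_block_ids = [1, 2, 3, 4]        # before: one flat list of block IDs
grouped_block_ids = [[1, 2, 3, 4]]   # after: one list per KV cache group
assert grouped_block_ids[0] == flat_block_ids

# Single-group callers now index group 0 where they used the flat list before.
for block_id in grouped_block_ids[0]:
    assert isinstance(block_id, int)

# A hypothetical model with two groups (e.g. full attention plus sliding
# window) would return two per-group lists, which is what the nesting is for.
two_group_block_ids = [[1, 2, 3, 4], [7, 8]]
assert len(two_group_block_ids) == 2
```
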
tests/v1/worker/test_gpu_input_batch.py: 33 additions, 6 deletions
@@ -9,9 +9,11 @@

from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+ KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
- from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
- InputBatch)
+ from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
+ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
@@ -22,6 +24,27 @@
MAX_NUM_PROMPT_TOKENS = 64


def get_kv_cache_config() -> KVCacheConfig:
return KVCacheConfig(
num_blocks=10,
tensors={
"layer.0": KVCacheTensor(size=1024),
},
kv_cache_groups=[
KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=1,
num_kv_heads=1,
head_size=16,
dtype=torch.float16,
use_mla=False,
),
),
],
)


def _compare_objs(obj1, obj2):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
elif isinstance(a, np.ndarray):
if np.allclose(a, b):
is_same = True
elif isinstance(a, MultiGroupBlockTable):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata)):
_compare_objs(a, b)
is_same = True # if we make it here must be same
@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params=_create_sampling_params(),
mm_inputs=[],
mm_positions=[],
- block_ids=[],
+ block_ids=[[]],
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,
@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
- max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
+ block_size=1,
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
- max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
+ block_size=1,
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
- max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
+ block_size=1,
)

reqs: list[CachedRequestState] = []
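
The import change and the new MultiGroupBlockTable branch in _compare_objs indicate that block tables are now kept per KV cache group, with the wrapper exposing a block_tables list that the test iterates. The sketch below only illustrates that wrapper shape under those assumptions; SimpleBlockTable, SimpleMultiGroupBlockTable, and their fields are hypothetical stand-ins, not vLLM's classes.

```python
import numpy as np


class SimpleBlockTable:
    """One table of block IDs per request, for a single KV cache group."""

    def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int):
        self.block_table_np = np.zeros(
            (max_num_reqs, max_num_blocks_per_req), dtype=np.int32)


class SimpleMultiGroupBlockTable:
    """Wrapper holding one block table per KV cache group."""

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 block_sizes: list[int]):
        # Each group can use its own block size, so the per-request capacity
        # (ceil division of max_model_len by block_size) differs per group.
        self.block_tables = [
            SimpleBlockTable(max_num_reqs, -(-max_model_len // block_size))
            for block_size in block_sizes
        ]


tables = SimpleMultiGroupBlockTable(max_num_reqs=4, max_model_len=1024,
                                    block_sizes=[16])
assert len(tables.block_tables) == 1
assert tables.block_tables[0].block_table_np.shape == (4, 64)
```

This also matches the InputBatch constructor change above, where a per-group block_size replaces the precomputed max_num_blocks_per_req, presumably so each group's table capacity can be derived from max_model_len and its own block size.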