
Commit a9f598d

markmc authored and huachenheli committed
Revert "[v1] Support multiple KV cache groups in GPU model runner (vllm-project#17945)" (vllm-project#18459)
Signed-off-by: Mark McLoughlin <[email protected]>
Signed-off-by: Chenheli Hua <[email protected]>
1 parent 45007b3 commit a9f598d

15 files changed: 214 additions, 481 deletions
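
Most of the test churn below comes from one interface change being reverted: with multiple KV cache groups, block IDs were returned nested, one list per group, whereas this revert restores a single flat list. A minimal standalone sketch of the shape difference, using plain Python lists to stand in for the return value of KVCacheBlocks.get_block_ids() rather than the vLLM API itself:

# Pre-revert (multi-group support): one inner list of block IDs per KV cache group.
grouped_block_ids = [[1, 2, 3, 4]]
# Post-revert (single group): a flat list of block IDs.
flat_block_ids = [1, 2, 3, 4]
# The test assertions in the hunks below change accordingly.
assert grouped_block_ids[0] == flat_block_ids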

tests/v1/core/test_kv_cache_utils.py

Lines changed: 3 additions & 68 deletions
@@ -19,8 +19,7 @@
                               hash_request_tokens,
                               unify_kv_cache_configs)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec, KVCacheTensor,
-                                        SlidingWindowSpec)
+                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request
 

@@ -55,14 +54,12 @@ def new_kv_cache_spec(block_size=16,
                       num_kv_heads=2,
                       head_size=64,
                       dtype=torch.float32,
-                      use_mla=False,
-                      sliding_window=None):
+                      use_mla=False):
     return FullAttentionSpec(block_size=block_size,
                              num_kv_heads=num_kv_heads,
                              head_size=head_size,
                              dtype=dtype,
-                             use_mla=use_mla,
-                             sliding_window=sliding_window)
+                             use_mla=use_mla)
 
 
 def test_none_hash(monkeypatch):

@@ -495,68 +492,6 @@ def test_unify_kv_cache_configs():
         unify_kv_cache_configs(diff_kv_cache_config)
 
 
-def test_merge_kv_cache_spec():
-    same_layer_specs = [
-        new_kv_cache_spec(num_kv_heads=32),
-        new_kv_cache_spec(num_kv_heads=32),
-    ]
-    merged_layer_spec = same_layer_specs[0].merge(same_layer_specs)
-    assert merged_layer_spec.block_size == 16
-    assert merged_layer_spec.num_kv_heads == 32
-    assert merged_layer_spec.head_size == 64
-    assert merged_layer_spec.dtype == torch.float32
-    assert merged_layer_spec.sliding_window is None
-
-    different_layer_specs = [
-        new_kv_cache_spec(num_kv_heads=32),
-        new_kv_cache_spec(num_kv_heads=16),
-    ]
-    with pytest.raises(AssertionError):
-        different_layer_specs[0].merge(different_layer_specs)
-
-    full_spec = new_kv_cache_spec(num_kv_heads=32)
-    different_type_layer_specs = [
-        full_spec,
-        SlidingWindowSpec(
-            block_size=full_spec.block_size,
-            num_kv_heads=full_spec.num_kv_heads,
-            head_size=full_spec.head_size,
-            dtype=full_spec.dtype,
-            use_mla=full_spec.use_mla,
-            sliding_window=1,
-        ),
-    ]
-    with pytest.raises(AssertionError):
-        different_type_layer_specs[0].merge(different_type_layer_specs)
-    with pytest.raises(AssertionError):
-        different_type_layer_specs[1].merge(different_type_layer_specs)
-
-    different_sliding_window_layer_specs = [
-        new_kv_cache_spec(num_kv_heads=32),
-        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
-        new_kv_cache_spec(num_kv_heads=32, sliding_window=2),
-    ]
-    with pytest.raises(ValueError):
-        different_sliding_window_layer_specs[0].merge(
-            different_sliding_window_layer_specs)
-
-    same_sliding_window_layer_specs = [
-        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
-        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
-    ]
-    merged_layer_spec = same_sliding_window_layer_specs[0].merge(
-        same_sliding_window_layer_specs)
-    assert merged_layer_spec.sliding_window == 1
-
-    same_sliding_window_layer_spec_with_none = [
-        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
-        new_kv_cache_spec(num_kv_heads=32, sliding_window=None),
-    ]
-    merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge(
-        same_sliding_window_layer_spec_with_none)
-    assert merged_layer_spec.sliding_window == 1
-
-
 @pytest.mark.parametrize(
     ("model_id", "max_model_len", "want_estimated_max_len"), [
         ("Qwen/Qwen1.5-7B", 16385, 16384),

tests/v1/core/test_prefix_caching.py

Lines changed: 18 additions & 18 deletions
@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
 
     # Check full block metadata
     parent_block_hash = None

@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == [1, 2, 3]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
+    assert blocks.get_block_ids() == [5]
     for block in computed_blocks.blocks:
         assert block.ref_cnt == 2
 

@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
     assert len(manager.req_to_block_hashes[req2.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == [1, 2, 3]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[6]]
+    assert blocks.get_block_ids() == [6]
 
     # Although we only have 6 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.

@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
     # This block ID order also checks the eviction order.
-    assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
+    assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
     assert manager.block_pool.free_block_queue.num_free_blocks == 0
     assert manager.block_pool.free_block_queue.free_list_head is None
     assert manager.block_pool.free_block_queue.free_list_tail is None

@@ -208,7 +208,7 @@ def test_prefill_plp():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
     req0_block_hashes = [b.block_hash for b in blocks.blocks]
 
     # Check full block metadata

@@ -233,13 +233,13 @@ def test_prefill_plp():
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == [1, 2, 3]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
+    assert blocks.get_block_ids() == [5]
     for block in computed_blocks.blocks:
         assert block.ref_cnt == 2
 

@@ -277,11 +277,11 @@ def test_prefill_plp():
     block_ids = blocks.get_block_ids()
     # Duplicate cached blocks have different ids but same hashes vs request #0
     assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
-    assert block_ids != [[1, 2, 3, 4]]
+    assert block_ids != [1, 2, 3, 4]
 
     # Request #2 block hashes are valid since request #0 hashes are.
     # Check block reference counts.
-    for block_id in block_ids[0]:
+    for block_id in block_ids:
         assert manager.block_pool.blocks[block_id].ref_cnt == 1
 
     manager.free(req2)

@@ -307,7 +307,7 @@ def test_decode():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
 
     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55

@@ -379,12 +379,12 @@ def test_evict():
     # Touch the first 2 blocks.
     req2 = make_request("2", list(range(2 * 16 + 3)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert computed_blocks.get_block_ids() == [[1, 2]]
+    assert computed_blocks.get_block_ids() == [1, 2]
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[10]]
+    assert blocks.get_block_ids() == [10]
     assert manager.block_pool.free_block_queue.num_free_blocks == 7
 
 

@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
     blocks = manager.allocate_slots(req0, 59,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
     req0.num_computed_tokens = 59
 
     # Append slots without allocating a new block.

@@ -686,7 +686,7 @@ def test_cache_key_salting():
     blocks = manager.allocate_slots(req0, 59,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
     req0.num_computed_tokens = 59
 
     # Append slots without allocating a new block.

@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     blocks = manager.allocate_slots(req0, 55)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
 
     unique_token_ids = [4] * 7
     all_token_ids = full_block_token_ids + unique_token_ids

@@ -808,7 +808,7 @@ def test_reset_prefix_cache():
     blocks = manager.allocate_slots(req1, 7,
                                     len(computed_blocks.blocks) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
+    assert blocks.get_block_ids() == [5]
 
     # Failed to reset prefix cache because some blocks are not freed yet.
     assert not manager.reset_prefix_cache()
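
Beyond the assertion values, the reference-count loop in test_prefill_plp changes in step with the return type: with grouped IDs the test had to index into the first (and only) group before iterating, while after the revert it iterates the flat list directly. A tiny sketch of the two call sites, where ref_cnt_of is a made-up helper standing in for manager.block_pool.blocks[block_id].ref_cnt:

def ref_cnt_of(block_id: int) -> int:
    # Dummy lookup; the real test reads the block pool's per-block ref_cnt.
    return 1


grouped_block_ids = [[7, 8, 9]]  # pre-revert: nested per KV cache group
flat_block_ids = [7, 8, 9]       # post-revert: flat

for block_id in grouped_block_ids[0]:  # pre-revert needed the [0] index
    assert ref_cnt_of(block_id) == 1
for block_id in flat_block_ids:        # post-revert iterates directly
    assert ref_cnt_of(block_id) == 1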

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 6 additions & 33 deletions
@@ -9,11 +9,9 @@
 
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
+                                            InputBatch)
 
 VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20

@@ -24,27 +22,6 @@
 MAX_NUM_PROMPT_TOKENS = 64
 
 
-def get_kv_cache_config() -> KVCacheConfig:
-    return KVCacheConfig(
-        num_blocks=10,
-        tensors={
-            "layer.0": KVCacheTensor(size=1024),
-        },
-        kv_cache_groups=[
-            KVCacheGroupSpec(
-                layer_names=["layer.0"],
-                kv_cache_spec=FullAttentionSpec(
-                    block_size=1,
-                    num_kv_heads=1,
-                    head_size=16,
-                    dtype=torch.float16,
-                    use_mla=False,
-                ),
-            ),
-        ],
-    )
-
-
 def _compare_objs(obj1, obj2):
     attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
     attr_names = set([

@@ -64,10 +41,6 @@ def _compare_objs(obj1, obj2):
         elif isinstance(a, np.ndarray):
             if np.allclose(a, b):
                 is_same = True
-        elif isinstance(a, MultiGroupBlockTable):
-            for a_i, b_i in zip(a.block_tables, b.block_tables):
-                _compare_objs(a_i, b_i)
-            is_same = True
         elif isinstance(a, (BlockTable, SamplingMetadata)):
             _compare_objs(a, b)
             is_same = True  # if we make it here must be same

@@ -225,7 +198,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         sampling_params=_create_sampling_params(),
         mm_inputs=[],
         mm_positions=[],
-        block_ids=[[]],
+        block_ids=[],
         generator=None,
         num_computed_tokens=len(output_token_ids),
         output_token_ids=output_token_ids,

@@ -247,11 +220,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
+        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        kv_cache_config=get_kv_cache_config(),
     )
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}

@@ -337,20 +310,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
     input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
+        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        kv_cache_config=get_kv_cache_config(),
     )
     ref_input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
+        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        kv_cache_config=get_kv_cache_config(),
     )
 
     reqs: list[CachedRequestState] = []
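
After the revert, InputBatch is again constructed with an explicit max_num_blocks_per_req instead of deriving its block tables from a KVCacheConfig. The tests above simply pass a small fixed value (10); a caller sizing it from the model length and block size would use a ceiling division. A minimal sketch with a helper name of our own choosing, not a vLLM function:

def blocks_needed(max_model_len: int, block_size: int) -> int:
    # Ceiling division: enough blocks to cover every token of the longest request.
    return (max_model_len + block_size - 1) // block_size


# With the values that appear in these tests (max_model_len=1024, 16-token blocks):
assert blocks_needed(1024, 16) == 64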

0 commit comments
