Commit 2ca014e

simplify specialized manager interface
Signed-off-by: Chen Zhang <[email protected]>
1 parent 818fb83 commit 2ca014e

File tree

4 files changed: +101 -161 lines changed
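
At a glance, the simplified interface this commit converges on looks roughly like the sketch below. It is inferred from the call sites and tests in this diff, not copied from vllm/v1/core/specialized_manager.py (presumably the fourth changed file, whose diff is not shown in this excerpt), so treat the class name and exact signatures as approximations.

    # Sketch only: approximate shape of the simplified specialized-manager
    # interface, reconstructed from the callers below.
    from abc import ABC, abstractmethod

    from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock


    class SpecializedManagerSketch(ABC):

        @abstractmethod
        def get_longest_cached_prefix(
                self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
            """Return the blocks of the longest cached prefix. Blocks that
            fall outside the attention window (e.g. a sliding window) may be
            the null block; callers derive num_computed_tokens from len()."""

        @abstractmethod
        def remove_useless_blocks(
                self, block_table: list[KVCacheBlock],
                num_computed_tokens: int) -> list[KVCacheBlock]:
            """Swap blocks no longer reachable by the attention window for
            the null block and return the removed blocks so the caller can
            free them. The previous is_first_call flag is gone."""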

tests/v1/core/test_specialized_manager.py

Lines changed: 49 additions & 45 deletions
@@ -3,8 +3,7 @@
 import torch
 
 from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
-                                         PrefixLengthRange)
+from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
 from vllm.v1.core.specialized_manager import SlidingWindowManager
 from vllm.v1.kv_cache_interface import SlidingWindowSpec
 
@@ -22,35 +21,50 @@ def test_sliding_window_possible_cached_prefix():
     block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
     manager = SlidingWindowManager(sliding_window_spec, block_pool)
 
-    block_is_cached = [
+    def run_one_case(block_is_cached, expect_length):
+        block_hash_list = [
+            BlockHashType(i, ()) for i in range(len(block_is_cached))
+        ]
+
+        block_pool.cached_block_hash_to_block.clear()
+
+        # Mock the block pool with the cached blocks
+        for i, (block_hash,
+                is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
+            if is_cached:
+                block_pool.cached_block_hash_to_block[block_hash] = {
+                    i: block_pool.blocks[i + 10]
+                }
+
+        computed_blocks = manager.get_longest_cached_prefix(block_hash_list)
+        assert len(computed_blocks) == expect_length
+
+        assert all(block == block_pool.get_null_block()
+                   for block in computed_blocks[:expect_length - 2])
+        for i in range(2):
+            if i < expect_length:
+                block_index = expect_length - i - 1
+                assert computed_blocks[
+                    block_index].block_id == block_index + 10
+
+    run_one_case([False] * 10, 0)
+    run_one_case([True], 1)
+    run_one_case([True, False], 1)
+    run_one_case([True, True], 2)
+    run_one_case([True, True, False], 2)
+    run_one_case([True, True, True], 3)
+    run_one_case([True, True, True, False], 3)
+    run_one_case([
         True, True, False, True, False, False, True, True, False, True, True,
         True
-    ]
-    block_hash_list = [
-        BlockHashType(i, ()) for i in range(len(block_is_cached))
-    ]
-
-    # Mock the block pool with the cached blocks
-    for i, (block_hash,
-            is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
-        if is_cached:
-            block_pool.cached_block_hash_to_block[block_hash] = {
-                i: block_pool.blocks[i + 10]
-            }
-
-    ranges, computed_blocks = manager.get_possible_cached_prefix(
-        block_hash_list)
-    assert ranges == [
-        PrefixLengthRange(0, 4),
-        PrefixLengthRange(16, 16),
-        PrefixLengthRange(22, 24)
-    ]
-    expected_computed_blocks = [
-        block_pool.blocks[i +
-                          10] if is_cached else block_pool.get_null_block()
-        for i, is_cached in enumerate(block_is_cached)
-    ]
-    assert computed_blocks == expected_computed_blocks
+    ], 12)
+    run_one_case([
+        True, True, False, True, False, False, True, True, False, False, False
+    ], 8)
+    run_one_case([
+        True, True, False, True, False, False, True, True, False, False, False,
+        True
+    ], 8)
 
 
 def test_sliding_window_remove_useless_blocks():
@@ -87,49 +101,39 @@ def assert_block_id(block_table, ids):
         1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
     ]
     block_table = id_to_block_table(original_block_ids)
-    removed = manager.remove_useless_blocks(block_table, 0, is_first_call=True)
+    removed = manager.remove_useless_blocks(block_table, 0)
     assert_block_id(removed, [])
     assert_block_id(block_table, original_block_ids)
 
     # 5 tokens are computed. Only token 0 is out of the sliding window. As
     # block 1000 also contains token 1 that is in the sliding window, block 1000
     # cannot be removed.
-    removed = manager.remove_useless_blocks(block_table,
-                                            5,
-                                            is_first_call=False)
+    removed = manager.remove_useless_blocks(block_table, 5)
     assert_block_id(removed, [])
     assert_block_id(block_table, original_block_ids)
 
     # 6 tokens are computed. Token 0 & 1 are out of the sliding window.
     # Block 1000 can be removed.
-    removed = manager.remove_useless_blocks(block_table,
-                                            6,
-                                            is_first_call=False)
+    removed = manager.remove_useless_blocks(block_table, 6)
     assert_block_id(removed, [original_block_ids[0]])
     assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
 
     # 7 tokens are computed. Token 0-2 are out of the sliding window.
     # Cannot remove new block as the block 1001 is still used by token 3.
-    removed = manager.remove_useless_blocks(block_table,
-                                            7,
-                                            is_first_call=False)
+    removed = manager.remove_useless_blocks(block_table, 7)
     assert_block_id(removed, [])
     assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
 
     # 8 tokens are computed. Token 0-3 are out of the sliding window.
     # Block 1001 can be removed and block 1000 is already removed.
-    removed = manager.remove_useless_blocks(block_table,
-                                            8,
-                                            is_first_call=False)
+    removed = manager.remove_useless_blocks(block_table, 8)
     assert_block_id(removed, [original_block_ids[1]])
     assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
 
     # 12 tokens are computed. Token 0-7 are out of the sliding window.
     # Block 1002 & 1003 can be removed now. Block 1003 represents a longer
     # sequence, and is expected to be evicted earlier than 1002, so the order
     # of removed blocks should be [1003, 1002].
-    removed = manager.remove_useless_blocks(block_table,
-                                            12,
-                                            is_first_call=False)
+    removed = manager.remove_useless_blocks(block_table, 12)
     assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
     assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
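
The run_one_case expectations above follow a simple rule: with a sliding window that spans two blocks (which is what the assertions on the last two block ids suggest for this test's spec), a prefix of n blocks counts as a hit when its last two blocks are cached, and everything earlier can be served by the null block. The snippet below is an illustrative standalone sketch of that rule, not the SlidingWindowManager implementation, using the test's own patterns as a sanity check.

    # Illustrative only: reproduces the expected lengths in the test above,
    # assuming the sliding window covers the last 2 blocks of a prefix.
    def longest_cached_prefix_len(block_is_cached: list[bool],
                                  window_blocks: int = 2) -> int:
        best = 0
        for n in range(1, len(block_is_cached) + 1):
            # A prefix of n blocks is usable if every block still inside the
            # sliding window (the last `window_blocks` of the prefix) is cached.
            if all(block_is_cached[max(0, n - window_blocks):n]):
                best = n
        return best


    assert longest_cached_prefix_len([False] * 10) == 0
    assert longest_cached_prefix_len([True, True, True, False]) == 3
    assert longest_cached_prefix_len([
        True, True, False, True, False, False, True, True, False, True, True,
        True
    ]) == 12
    assert longest_cached_prefix_len([
        True, True, False, True, False, False, True, True, False, False, False
    ]) == 8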

vllm/v1/core/kv_cache_manager.py

Lines changed: 15 additions & 40 deletions
@@ -127,31 +127,9 @@ def get_computed_blocks(
 
         self.prefix_cache_stats.requests += 1
         if request.sampling_params.prompt_logprobs is None:
-            # Check for cache hits
-            # E.g., for a model with sliding window size 32 (2 * block_size)
-            # computed_blocks = [NULL, 5, 2, NULL, 9, 7, 3, NULL]
-            # We can have the first 3 blocks, 5 blocks, or 6 blocks as
-            # the cached prefix, so the prefix_length should be:
-            # prefix_length = [
-            #     PrefixLengthRange(3 * 16, 3 * 16),
-            #     PrefixLengthRange(6 * 16, 7 * 16)
-            # ]
-            prefix_length, computed_blocks = \
-                self.specialized_manager.get_possible_cached_prefix(
-                    block_hashes)
-            # E.g., num_computed_tokens = 7 * 16
-            num_computed_tokens = prefix_length[-1].end
-            # NOTE(woosuk): Since incomplete blocks are not eligible for
-            # sharing, `num_computed_tokens` should always be a multiple of
-            # `block_size`.
-            assert num_computed_tokens % self.block_size == 0
-            # E.g., computed_blocks = [NULL, 5, 2, NULL, 9, 7, 3]
-            computed_blocks = computed_blocks[:num_computed_tokens //
-                                              self.block_size]
-            # E.g., computed_blocks = [NULL, NULL, NULL, NULL, 9, 7, 3]
-            self._free_useless_blocks(computed_blocks,
-                                      num_computed_tokens,
-                                      touched=False)
+            computed_blocks = (self.specialized_manager.
+                               get_longest_cached_prefix(block_hashes))
+            num_computed_tokens = len(computed_blocks) * self.block_size
 
         self.prefix_cache_stats.queries += len(block_hashes)
         self.prefix_cache_stats.hits += len(computed_blocks)
@@ -196,20 +174,23 @@ def allocate_slots(
 
         new_computed_blocks = new_computed_blocks or []
 
+        req_blocks = self.req_to_blocks[request.request_id]
+
+        # We can free blocks that are no longer needed even if we cannot
+        # schedule this request due to the limit of free blocks.
+        # Should call this function before allocating new blocks to reduce
+        # the number of evicted blocks.
+        removed_blocks = self.specialized_manager.remove_useless_blocks(
+            req_blocks, request.num_computed_tokens)
+        self.block_pool.free_blocks(removed_blocks)
+
         # The number of computed tokens is the number of computed tokens plus
         # the new prefix caching hits
         num_computed_tokens = (request.num_computed_tokens +
                                len(new_computed_blocks) * self.block_size)
         num_required_blocks = cdiv(num_computed_tokens + num_tokens,
                                    self.block_size)
-        req_blocks = self.req_to_blocks[request.request_id]
-        # We can free blocks that are no longer needed even if we cannot
-        # schedule this request due to the limit of free blocks.
-        # Should call this function before allocating new blocks to reduce
-        # the number of evicted blocks.
-        self._free_useless_blocks(req_blocks,
-                                  request.num_computed_tokens,
-                                  touched=True)
+
         num_new_blocks = (num_required_blocks - len(req_blocks) -
                           len(new_computed_blocks))
 
@@ -231,12 +212,6 @@ def allocate_slots(
                 "Computed blocks should be empty when "
                 "prefix caching is disabled")
 
-        # Should call this function before allocating new blocks to reduce
-        # the number of evicted blocks.
-        self._free_useless_blocks(req_blocks,
-                                  request.num_computed_tokens,
-                                  touched=True)
-
         # Append the new computed blocks to the request blocks until now to
         # avoid the case where the new blocks cannot be allocated.
         req_blocks.extend(new_computed_blocks)
@@ -398,6 +373,6 @@ def _free_useless_blocks(self, req_blocks: list[KVCacheBlock],
         # The first call always comes from `get_computed_blocks` which
         # passes `touched=False`.
         removed_blocks = self.specialized_manager.remove_useless_blocks(
-            req_blocks, num_computed_tokens, is_first_call=not touched)
+            req_blocks, num_computed_tokens)
         if touched:
             self.block_pool.free_blocks(removed_blocks)
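
The assertion removed from get_computed_blocks (num_computed_tokens % block_size == 0) is no longer needed: get_longest_cached_prefix only ever returns whole blocks, so the token count is a multiple of the block size by construction. A toy illustration with made-up numbers, not vLLM objects:

    # Toy numbers only, to show the arithmetic of the new get_computed_blocks path.
    block_size = 16
    # Suppose the manager reports a 7-block cached prefix (some entries may be
    # the null block for a sliding-window layer).
    computed_blocks = ["blk"] * 7
    num_computed_tokens = len(computed_blocks) * block_size
    assert num_computed_tokens == 112              # 7 * 16
    assert num_computed_tokens % block_size == 0   # whole blocks by construction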

vllm/v1/core/kv_cache_utils.py

Lines changed: 0 additions & 8 deletions
@@ -686,11 +686,3 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
         kv_cache_config.num_blocks = min_num_blocks
 
     return kv_cache_configs
-
-
-class PrefixLengthRange(NamedTuple):
-    """
-    A closed interval [start, end] representing a range of valid prefix lengths.
-    """
-    start: int
-    end: int
