Skip to content

Commit 941f770

Browse files
committed
update based on review
Signed-off-by: Chen Zhang <[email protected]>
1 parent 2ca014e commit 941f770

File tree

4 files changed

+43
-46
lines changed

4 files changed

+43
-46
lines changed

tests/v1/core/test_specialized_manager.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ def run_one_case(block_is_cached, expect_length):
3636
i: block_pool.blocks[i + 10]
3737
}
3838

39-
computed_blocks = manager.get_longest_cached_prefix(block_hash_list)
39+
computed_blocks = manager.find_longest_cache_hit(block_hash_list)
4040
assert len(computed_blocks) == expect_length
4141

42-
assert all(block == block_pool.get_null_block()
42+
assert all(block == block_pool.null_block
4343
for block in computed_blocks[:expect_length - 2])
4444
for i in range(2):
4545
if i < expect_length:
@@ -67,7 +67,7 @@ def run_one_case(block_is_cached, expect_length):
6767
], 8)
6868

6969

70-
def test_sliding_window_remove_useless_blocks():
70+
def test_sliding_window_remove_skipped_blocks():
7171
sliding_window_spec = SlidingWindowSpec(
7272
block_size=2,
7373
num_kv_heads=1,
@@ -81,59 +81,58 @@ def test_sliding_window_remove_useless_blocks():
8181

8282
manager = SlidingWindowManager(sliding_window_spec, block_pool)
8383

84-
null_block_id = block_pool.get_null_block().block_id
84+
null_block_id = block_pool.null_block.block_id
8585

8686
def id_to_block_table(ids):
8787
return [
8888
KVCacheBlock(id_)
89-
if id_ != null_block_id else block_pool.get_null_block()
90-
for id_ in ids
89+
if id_ != null_block_id else block_pool.null_block for id_ in ids
9190
]
9291

9392
def assert_block_id(block_table, ids):
9493
for block, id_ in zip(block_table, ids):
9594
if id_ == null_block_id:
96-
assert block == block_pool.get_null_block()
95+
assert block == block_pool.null_block
9796
else:
9897
assert block.block_id == id_
9998

10099
original_block_ids = [
101100
1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
102101
]
103102
block_table = id_to_block_table(original_block_ids)
104-
removed = manager.remove_useless_blocks(block_table, 0)
103+
removed = manager.remove_skipped_blocks(block_table, 0)
105104
assert_block_id(removed, [])
106105
assert_block_id(block_table, original_block_ids)
107106

108107
# 5 tokens are computed. Only token 0 is out of the sliding window. As
109108
# block 1000 also contains token 1 that is in the sliding window, block 1000
110109
# cannot be removed.
111-
removed = manager.remove_useless_blocks(block_table, 5)
110+
removed = manager.remove_skipped_blocks(block_table, 5)
112111
assert_block_id(removed, [])
113112
assert_block_id(block_table, original_block_ids)
114113

115114
# 6 tokens are computed. Token 0 & 1 are out of the sliding window.
116115
# Block 1000 can be removed.
117-
removed = manager.remove_useless_blocks(block_table, 6)
116+
removed = manager.remove_skipped_blocks(block_table, 6)
118117
assert_block_id(removed, [original_block_ids[0]])
119118
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
120119

121120
# 7 tokens are computed. Token 0-2 are out of the sliding window.
122121
# Cannot remove new block as the block 1001 is still used by token 3.
123-
removed = manager.remove_useless_blocks(block_table, 7)
122+
removed = manager.remove_skipped_blocks(block_table, 7)
124123
assert_block_id(removed, [])
125124
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
126125

127126
# 8 tokens are computed. Token 0-3 are out of the sliding window.
128127
# Block 1001 can be removed and block 1000 is already removed.
129-
removed = manager.remove_useless_blocks(block_table, 8)
128+
removed = manager.remove_skipped_blocks(block_table, 8)
130129
assert_block_id(removed, [original_block_ids[1]])
131130
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
132131

133132
# 12 tokens are computed. Token 0-7 are out of the sliding window.
134133
# Block 1002 & 1003 can be removed now. Block 1003 represents a longer
135134
# sequence, and is expected to be evicted earlier than 1002, so the order
136135
# of removed blocks should be [1003, 1002].
137-
removed = manager.remove_useless_blocks(block_table, 12)
136+
removed = manager.remove_skipped_blocks(block_table, 12)
138137
assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
139138
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])

vllm/v1/core/block_pool.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool):
5454
# To represent a placeholder block with block_id=0.
5555
# The ref_cnt of null_block is not maintained, needs special care to
5656
# avoid freeing it.
57-
self._null_block = self.free_block_queue.popleft()
57+
self.null_block = self.free_block_queue.popleft()
5858

5959
def get_cached_block(self,
6060
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
@@ -220,7 +220,7 @@ def touch(self, blocks: list[KVCacheBlock]) -> None:
220220
for block in blocks:
221221
# ref_cnt=0 means this block is in the free list (i.e. eviction
222222
# candidate), so remove it.
223-
if block.ref_cnt == 0 and block != self._null_block:
223+
if block.ref_cnt == 0 and block != self.null_block:
224224
self.free_block_queue.remove(block)
225225
block.incr_ref()
226226

@@ -235,7 +235,7 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
235235
for block in ordered_blocks:
236236
block.decr_ref()
237237
# null_block should not be added to the free list.
238-
if block.ref_cnt == 0 and block != self._null_block:
238+
if block.ref_cnt == 0 and block != self.null_block:
239239
self.free_block_queue.append(block)
240240

241241
def reset_prefix_cache(self) -> bool:
@@ -279,11 +279,3 @@ def get_usage(self) -> float:
279279
The KV cache usage (between 0.0 and 1.0).
280280
"""
281281
return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks)
282-
283-
def get_null_block(self) -> KVCacheBlock:
284-
"""Get the null block.
285-
286-
Returns:
287-
The null block.
288-
"""
289-
return self._null_block

vllm/v1/core/kv_cache_manager.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,16 @@ def get_computed_blocks(
127127

128128
self.prefix_cache_stats.requests += 1
129129
if request.sampling_params.prompt_logprobs is None:
130-
computed_blocks = (self.specialized_manager.
131-
get_longest_cached_prefix(block_hashes))
132-
num_computed_tokens = len(computed_blocks) * self.block_size
130+
computed_blocks = (
131+
self.specialized_manager.find_longest_cache_hit(block_hashes))
133132

134133
self.prefix_cache_stats.queries += len(block_hashes)
135134
self.prefix_cache_stats.hits += len(computed_blocks)
136135

136+
# NOTE(woosuk): Since incomplete blocks are not eligible for
137+
# sharing, `num_computed_tokens` is always a multiple of
138+
# `block_size`.
139+
num_computed_tokens = len(computed_blocks) * self.block_size
137140
return computed_blocks, num_computed_tokens
138141
else:
139142
# Skip cache hits for prompt logprobs
@@ -176,11 +179,13 @@ def allocate_slots(
176179

177180
req_blocks = self.req_to_blocks[request.request_id]
178181

179-
# We can free blocks that are no longer needed even if we cannot
180-
# schedule this request due to the limit of free blocks.
182+
# Free the blocks that are skipped during the attention computation
183+
# (e.g., tokens outside the sliding window).
184+
# We can do this even if we cannot schedule this request due to
185+
# insufficient free blocks.
181186
# Should call this function before allocating new blocks to reduce
182187
# the number of evicted blocks.
183-
removed_blocks = self.specialized_manager.remove_useless_blocks(
188+
removed_blocks = self.specialized_manager.remove_skipped_blocks(
184189
req_blocks, request.num_computed_tokens)
185190
self.block_pool.free_blocks(removed_blocks)
186191

@@ -372,7 +377,7 @@ def _free_useless_blocks(self, req_blocks: list[KVCacheBlock],
372377
"""
373378
# The first call always comes from `get_computed_blocks` which
374379
# passes `touched=False`.
375-
removed_blocks = self.specialized_manager.remove_useless_blocks(
380+
removed_blocks = self.specialized_manager.remove_skipped_blocks(
376381
req_blocks, num_computed_tokens)
377382
if touched:
378383
self.block_pool.free_blocks(removed_blocks)

vllm/v1/core/specialized_manager.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def __init__(
3030
self.block_pool = block_pool
3131

3232
@abstractmethod
33-
def get_longest_cached_prefix(
33+
def find_longest_cache_hit(
3434
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
3535
"""
36-
Get the longest cached prefix of the blocks. If no cached prefix is
36+
Get the longest cache hit prefix of the blocks. If no cache hit is
3737
found, returns an empty list.
3838
3939
Args:
@@ -48,7 +48,7 @@ def get_longest_cached_prefix(
4848
raise NotImplementedError
4949

5050
@abstractmethod
51-
def remove_useless_blocks(self, blocks: list[KVCacheBlock],
51+
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
5252
num_computed_tokens: int) -> list[KVCacheBlock]:
5353
"""
5454
Remove the blocks that are no longer needed from. The removed blocks
@@ -66,7 +66,7 @@ def remove_useless_blocks(self, blocks: list[KVCacheBlock],
6666

6767
class FullAttentionManager(SpecializedManager):
6868

69-
def get_longest_cached_prefix(
69+
def find_longest_cache_hit(
7070
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
7171
computed_blocks: list[KVCacheBlock] = []
7272
for block_hash in block_hashes:
@@ -79,7 +79,7 @@ def get_longest_cached_prefix(
7979
break
8080
return computed_blocks
8181

82-
def remove_useless_blocks(self, blocks: list[KVCacheBlock],
82+
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
8383
num_computed_tokens: int) -> list[KVCacheBlock]:
8484
# No need to remove blocks for full attention.
8585
return []
@@ -91,9 +91,9 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec,
9191
block_pool: BlockPool):
9292
super().__init__(kv_cache_spec, block_pool)
9393
self.sliding_window = kv_cache_spec.sliding_window
94-
self._null_block = block_pool.get_null_block()
94+
self._null_block = block_pool.null_block
9595

96-
def get_longest_cached_prefix(
96+
def find_longest_cache_hit(
9797
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
9898
# TODO: reduce i by num_block_sliding_window when cache miss, to
9999
# optimize the time complexity from O(len(block_hashes)) to
@@ -102,22 +102,23 @@ def get_longest_cached_prefix(
102102
# which is good for low cache hit rate scenarios.
103103
computed_blocks: list[KVCacheBlock] = [self._null_block
104104
] * len(block_hashes)
105-
num_computed_blocks = 0
105+
num_contiguous_blocks = 0
106106

107107
for i in range(len(block_hashes) - 1, -1, -1):
108108
if cached_block := self.block_pool.get_cached_block(
109109
block_hashes[i]):
110110
computed_blocks[i] = cached_block
111-
num_computed_blocks += 1
112-
if num_computed_blocks * self.block_size >= self.sliding_window:
113-
del computed_blocks[i + num_computed_blocks:]
111+
num_contiguous_blocks += 1
112+
if (num_contiguous_blocks * self.block_size
113+
>= self.sliding_window):
114+
del computed_blocks[i + num_contiguous_blocks:]
114115
return computed_blocks
115116
else:
116-
num_computed_blocks = 0
117-
del computed_blocks[num_computed_blocks:]
117+
num_contiguous_blocks = 0
118+
del computed_blocks[num_contiguous_blocks:]
118119
return computed_blocks
119120

120-
def remove_useless_blocks(self, blocks: list[KVCacheBlock],
121+
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
121122
num_computed_tokens: int) -> list[KVCacheBlock]:
122123
# Remove the blocks that are no longer be in the sliding window.
123124
last_useful_token = num_computed_tokens - self.sliding_window
@@ -137,7 +138,7 @@ def remove_useless_blocks(self, blocks: list[KVCacheBlock],
137138

138139
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
139140
FullAttentionSpec: FullAttentionManager,
140-
SlidingWindowSpec: SlidingWindowManager
141+
SlidingWindowSpec: SlidingWindowManager,
141142
}
142143

143144

0 commit comments

Comments
 (0)