3
3
import torch
4
4
5
5
from vllm .v1 .core .block_pool import BlockPool
6
- from vllm .v1 .core .kv_cache_utils import (BlockHashType , KVCacheBlock ,
7
- PrefixLengthRange )
6
+ from vllm .v1 .core .kv_cache_utils import BlockHashType , KVCacheBlock
8
7
from vllm .v1 .core .specialized_manager import SlidingWindowManager
9
8
from vllm .v1 .kv_cache_interface import SlidingWindowSpec
10
9
@@ -22,35 +21,50 @@ def test_sliding_window_possible_cached_prefix():
22
21
block_pool = BlockPool (num_gpu_blocks = 100 , enable_caching = True )
23
22
manager = SlidingWindowManager (sliding_window_spec , block_pool )
24
23
25
- block_is_cached = [
24
def run_one_case(block_is_cached, expect_length):
    """Check the cached prefix returned for one cache-hit pattern.

    Relies on ``block_pool`` and ``manager`` from the enclosing test.

    Args:
        block_is_cached: One boolean per block; ``True`` means the block at
            that position is registered in the block pool's hash cache.
        expect_length: Number of blocks that
            ``manager.get_longest_cached_prefix`` is expected to return.
    """
    num_blocks = len(block_is_cached)
    block_hash_list = [BlockHashType(idx, ()) for idx in range(num_blocks)]

    # Start from an empty cache so earlier cases cannot leak into this one.
    hash_to_block = block_pool.cached_block_hash_to_block
    hash_to_block.clear()

    # Mock the block pool: position ``idx`` is backed by pool block
    # ``idx + 10`` whenever it is flagged as cached.
    for idx, cached in enumerate(block_is_cached):
        if not cached:
            continue
        hash_to_block[block_hash_list[idx]] = {idx: block_pool.blocks[idx + 10]}

    computed_blocks = manager.get_longest_cached_prefix(block_hash_list)
    assert len(computed_blocks) == expect_length

    # Everything except the last two returned blocks must be the null block.
    # For expect_length < 2 the stop index is negative; because the length
    # was asserted above, that slice is still empty.
    leading_blocks = computed_blocks[:expect_length - 2]
    assert all(
        candidate == block_pool.get_null_block() for candidate in leading_blocks)

    # The last (up to) two returned blocks must be the real cached blocks.
    for offset in range(min(2, expect_length)):
        block_index = expect_length - 1 - offset
        assert computed_blocks[block_index].block_id == block_index + 10
49
+
50
+ run_one_case ([False ] * 10 , 0 )
51
+ run_one_case ([True ], 1 )
52
+ run_one_case ([True , False ], 1 )
53
+ run_one_case ([True , True ], 2 )
54
+ run_one_case ([True , True , False ], 2 )
55
+ run_one_case ([True , True , True ], 3 )
56
+ run_one_case ([True , True , True , False ], 3 )
57
+ run_one_case ([
26
58
True , True , False , True , False , False , True , True , False , True , True ,
27
59
True
28
- ]
29
- block_hash_list = [
30
- BlockHashType (i , ()) for i in range (len (block_is_cached ))
31
- ]
32
-
33
- # Mock the block pool with the cached blocks
34
- for i , (block_hash ,
35
- is_cached ) in enumerate (zip (block_hash_list , block_is_cached )):
36
- if is_cached :
37
- block_pool .cached_block_hash_to_block [block_hash ] = {
38
- i : block_pool .blocks [i + 10 ]
39
- }
40
-
41
- ranges , computed_blocks = manager .get_possible_cached_prefix (
42
- block_hash_list )
43
- assert ranges == [
44
- PrefixLengthRange (0 , 4 ),
45
- PrefixLengthRange (16 , 16 ),
46
- PrefixLengthRange (22 , 24 )
47
- ]
48
- expected_computed_blocks = [
49
- block_pool .blocks [i +
50
- 10 ] if is_cached else block_pool .get_null_block ()
51
- for i , is_cached in enumerate (block_is_cached )
52
- ]
53
- assert computed_blocks == expected_computed_blocks
60
+ ], 12 )
61
+ run_one_case ([
62
+ True , True , False , True , False , False , True , True , False , False , False
63
+ ], 8 )
64
+ run_one_case ([
65
+ True , True , False , True , False , False , True , True , False , False , False ,
66
+ True
67
+ ], 8 )
54
68
55
69
56
70
def test_sliding_window_remove_useless_blocks ():
@@ -87,49 +101,39 @@ def assert_block_id(block_table, ids):
87
101
1000 , 1001 , 1002 , 1003 , 1004 , 1005 , 1006 , 1007 , 1008 , 1009 , 1010
88
102
]
89
103
block_table = id_to_block_table (original_block_ids )
90
- removed = manager .remove_useless_blocks (block_table , 0 , is_first_call = True )
104
+ removed = manager .remove_useless_blocks (block_table , 0 )
91
105
assert_block_id (removed , [])
92
106
assert_block_id (block_table , original_block_ids )
93
107
94
108
# 5 tokens are computed. Only token 0 is out of the sliding window. As
95
109
# block 1000 also contains token 1 that is in the sliding window, block 1000
96
110
# cannot be removed.
97
- removed = manager .remove_useless_blocks (block_table ,
98
- 5 ,
99
- is_first_call = False )
111
+ removed = manager .remove_useless_blocks (block_table , 5 )
100
112
assert_block_id (removed , [])
101
113
assert_block_id (block_table , original_block_ids )
102
114
103
115
# 6 tokens are computed. Token 0 & 1 are out of the sliding window.
104
116
# Block 1000 can be removed.
105
- removed = manager .remove_useless_blocks (block_table ,
106
- 6 ,
107
- is_first_call = False )
117
+ removed = manager .remove_useless_blocks (block_table , 6 )
108
118
assert_block_id (removed , [original_block_ids [0 ]])
109
119
assert_block_id (block_table , [null_block_id ] + original_block_ids [1 :])
110
120
111
121
# 7 tokens are computed. Token 0-2 are out of the sliding window.
112
122
# Cannot remove new block as the block 1001 is still used by token 3.
113
- removed = manager .remove_useless_blocks (block_table ,
114
- 7 ,
115
- is_first_call = False )
123
+ removed = manager .remove_useless_blocks (block_table , 7 )
116
124
assert_block_id (removed , [])
117
125
assert_block_id (block_table , [null_block_id ] + original_block_ids [1 :])
118
126
119
127
# 8 tokens are computed. Token 0-3 are out of the sliding window.
120
128
# Block 1001 can be removed and block 1000 is already removed.
121
- removed = manager .remove_useless_blocks (block_table ,
122
- 8 ,
123
- is_first_call = False )
129
+ removed = manager .remove_useless_blocks (block_table , 8 )
124
130
assert_block_id (removed , [original_block_ids [1 ]])
125
131
assert_block_id (block_table , [null_block_id ] * 2 + original_block_ids [2 :])
126
132
127
133
# 12 tokens are computed. Token 0-7 are out of the sliding window.
128
134
# Block 1002 & 1003 can be removed now. Block 1003 represents a longer
129
135
# sequence, and is expected to be evicted earlier than 1002, so the order
130
136
# of removed blocks should be [1003, 1002].
131
- removed = manager .remove_useless_blocks (block_table ,
132
- 12 ,
133
- is_first_call = False )
137
+ removed = manager .remove_useless_blocks (block_table , 12 )
134
138
assert_block_id (removed , [original_block_ids [3 ], original_block_ids [2 ]])
135
139
assert_block_id (block_table , [null_block_id ] * 4 + original_block_ids [4 :])
0 commit comments