Skip to content

Commit d91457d

Browse files
authored
[V1] Add kv cache utils tests. (#11513)
Signed-off-by: xcnick <[email protected]>
1 parent fbf2564 commit d91457d

File tree

1 file changed

+241
-0
lines changed

1 file changed

+241
-0
lines changed

tests/v1/core/test_kv_cache_utils.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
import pytest
2+
3+
from vllm.inputs import token_inputs
4+
from vllm.sampling_params import SamplingParams
5+
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
6+
KVCacheBlock,
7+
generate_block_hash_extra_keys,
8+
hash_block_tokens,
9+
hash_request_tokens)
10+
from vllm.v1.request import Request
11+
12+
13+
def make_request(request_id,
                 prompt_token_ids,
                 mm_positions=None,
                 mm_hashes=None):
    """Build a ``Request`` with fixed boilerplate for the tests below.

    Args:
        request_id: Identifier forwarded to ``Request``.
        prompt_token_ids: Prompt tokens forwarded to ``token_inputs``.
        mm_positions: Optional placeholder positions. When truthy they are
            registered under the ``"image"`` modality; otherwise no
            multi-modal placeholders are passed at all.
        mm_hashes: Optional multi-modal hashes, forwarded unchanged.

    Returns:
        A ``Request`` using ``max_tokens=17``, ``eos_token_id=100``,
        ``arrival_time=0`` and no LoRA request.
    """
    # A falsy value (None or an empty list) means "no placeholders".
    placeholders = None
    if mm_positions:
        placeholders = {"image": mm_positions}

    request_inputs = token_inputs(
        prompt_token_ids=prompt_token_ids,
        multi_modal_placeholders=placeholders,
        multi_modal_hashes=mm_hashes,
    )
    return Request(
        request_id=request_id,
        inputs=request_inputs,
        sampling_params=SamplingParams(max_tokens=17),
        eos_token_id=100,
        arrival_time=0,
        lora_request=None,
    )
30+
31+
32+
def test_kv_cache_block():
    """Check KVCacheBlock defaults, ref counting and hash set/reset."""
    # A freshly built block carries its id, no references and no hash.
    blk = KVCacheBlock(block_id=0)
    assert blk.block_id == 0
    assert blk.ref_cnt == 0
    assert blk.block_hash is None

    # One increment followed by one decrement returns to zero.
    blk.incr_ref()
    assert blk.ref_cnt == 1
    blk.decr_ref()
    assert blk.ref_cnt == 0

    # An assigned hash is readable until reset_hash() clears it.
    expected_hash = BlockHashType(hash_value=123, token_ids=(1, 2, 3))
    blk.block_hash = expected_hash
    assert blk.block_hash == expected_hash

    blk.reset_hash()
    assert blk.block_hash is None
52+
53+
54+
def test_free_kv_cache_block_queue_initialization():
    """A queue built from one block has that block as both head and tail."""
    only_block = KVCacheBlock(block_id=0)
    free_queue = FreeKVCacheBlockQueue([only_block])

    assert free_queue.num_free_blocks == 1
    assert free_queue.free_list_head == only_block
    assert free_queue.free_list_tail == only_block
61+
62+
63+
def test_free_kv_cache_block_queue_operations():
    """Drive a five-block free queue through popleft, remove and append.

    Verifies the free-block counter, the head/tail pointers and the
    prev/next links after each step, then checks that popping from an
    empty queue raises ``ValueError("No free blocks available")``.
    """
    blocks = [KVCacheBlock(block_id=i) for i in range(5)]
    first, second, middle, fourth, last = blocks
    queue = FreeKVCacheBlockQueue(blocks)

    # Initial state: every block is free, ordered as given.
    assert queue.num_free_blocks == 5
    assert queue.free_list_head == first
    assert queue.free_list_tail == last

    # popleft() hands out the head and advances it.
    popped = queue.popleft()
    assert popped == first
    assert queue.num_free_blocks == 4
    assert queue.free_list_head == second
    assert queue.free_list_tail == last

    # Removing an interior block links its two neighbours together.
    queue.remove(middle)
    assert queue.num_free_blocks == 3
    assert second.next_free_block == fourth
    assert fourth.prev_free_block == second

    # Appending puts the block at the tail, after the previous tail.
    queue.append(middle)
    assert queue.num_free_blocks == 4
    assert queue.free_list_tail == middle
    assert middle.prev_free_block == last
    assert middle.next_free_block is None

    # Drain the four remaining blocks.
    for _ in range(4):
        queue.popleft()
    assert queue.num_free_blocks == 0
    assert queue.free_list_head is None
    assert queue.free_list_tail is None

    # Popping from the now-empty queue must fail with the exact message.
    with pytest.raises(ValueError) as exc_info:
        queue.popleft()
    assert str(exc_info.value) == "No free blocks available"
107+
108+
109+
def test_free_kv_cache_block_queue_get_all_free_blocks():
    """get_all_free_blocks() reflects popleft, remove and append in order."""
    blocks = [KVCacheBlock(block_id=i) for i in range(5)]
    queue = FreeKVCacheBlockQueue(blocks)

    # Untouched queue: all five blocks in their original order.
    assert queue.get_all_free_blocks() == blocks

    # After popping the head, blocks 1..4 remain.
    queue.popleft()
    expected = [blocks[1], blocks[2], blocks[3], blocks[4]]
    assert queue.get_all_free_blocks() == expected

    # After removing block 2 from the middle, blocks 1, 3, 4 remain.
    removed = blocks[2]
    queue.remove(removed)
    expected = [blocks[1], blocks[3], blocks[4]]
    assert queue.get_all_free_blocks() == expected

    # Appending the removed block places it last.
    queue.append(removed)
    expected = [blocks[1], blocks[3], blocks[4], removed]
    assert queue.get_all_free_blocks() == expected
132+
133+
134+
def test_generate_block_hash_extra_keys():
    """Check extra keys and the next mm index for several token ranges.

    The request has 20 prompt tokens and two multi-modal inputs:
    "hash1" at offset 0 (length 5) and "hash2" at offset 10 (length 5).
    Every call starts scanning from multi-modal index 0.
    """
    request = make_request(
        request_id=0,
        prompt_token_ids=list(range(20)),
        mm_positions=[{
            "offset": 0,
            "length": 5
        }, {
            "offset": 10,
            "length": 5
        }],
        mm_hashes=["hash1", "hash2"],
    )

    # Range 0..5 lines up with the first mm input: one key, at offset 0.
    extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0)
    assert extra_keys == (("hash1", 0), )
    assert next_mm_idx == 1

    # Range 3..8 only partially overlaps the first mm input; the key
    # carries 3 as its second element.
    extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 3, 8, 0)
    assert extra_keys == (("hash1", 3), )
    assert next_mm_idx == 1

    # Range 6..10 overlaps neither mm input: an empty tuple is expected,
    # and the mm index still moves past the first input.
    extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 6, 10, 0)
    assert extra_keys == ()
    assert next_mm_idx == 1

    # Range 0..15 spans both mm inputs: one key per input.
    extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 15, 0)
    assert extra_keys == (("hash1", 0), ("hash2", 0))
    assert next_mm_idx == 2
167+
168+
169+
def test_generate_block_hash_extra_keys_no_mm_inputs():
    """Without multi-modal inputs, extra keys are None and the index is 0."""
    request = make_request(
        request_id=0,
        prompt_token_ids=list(range(6)),
        mm_positions=None,
        mm_hashes=None,
    )

    extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0)
    # None (not an empty tuple) is the expected result for a text-only
    # request, and the mm index is returned unchanged.
    assert extra_keys is None
    assert next_mm_idx == 0
180+
181+
182+
def test_hash_block_tokens():
    """hash_block_tokens returns a BlockHashType with the expected fields."""
    parent_hash = 123
    token_ids = (1, 2, 3)
    keys = ("key1", "key2")

    # The expected hash value covers the parent hash followed by the
    # block's token ids; the extra keys are checked as a separate field.
    expected_hash_value = hash((parent_hash, *token_ids))

    result = hash_block_tokens(parent_hash, token_ids, keys)

    assert isinstance(result, BlockHashType)
    assert result.hash_value == expected_hash_value
    assert result.token_ids == token_ids
    assert result.extra_keys == keys
194+
195+
196+
def test_hash_request_tokens():
    """Six tokens with block size 3 produce two hashes, one per mm input.

    The two multi-modal inputs ("hash1" at offset 0, "hash2" at offset 3,
    both of length 3) each line up with exactly one block.
    """
    request = make_request(
        request_id=0,
        prompt_token_ids=list(range(6)),
        mm_positions=[{
            "offset": 0,
            "length": 3
        }, {
            "offset": 3,
            "length": 3
        }],
        mm_hashes=["hash1", "hash2"],
    )

    block_size = 3
    block_hashes = hash_request_tokens(block_size, request)

    # (token_ids, extra_keys) expected for each block, in order.
    expected_blocks = [
        ((0, 1, 2), (("hash1", 0), )),
        ((3, 4, 5), (("hash2", 0), )),
    ]

    assert len(block_hashes) == len(expected_blocks)
    for block_hash, (token_ids, extra_keys) in zip(block_hashes,
                                                   expected_blocks):
        assert isinstance(block_hash, BlockHashType)
        assert block_hash.token_ids == token_ids
        assert block_hash.extra_keys == extra_keys
224+
225+
226+
def test_hash_request_tokens_no_mm_inputs():
    """A text-only request yields block hashes whose extra_keys are None."""
    request = make_request(
        request_id=0,
        prompt_token_ids=list(range(6)),
        mm_positions=None,
        mm_hashes=None,
    )

    block_size = 3
    block_hashes = hash_request_tokens(block_size, request)

    # Six tokens split into two blocks of three, in prompt order.
    expected_token_ids = [(0, 1, 2), (3, 4, 5)]

    assert len(block_hashes) == len(expected_token_ids)
    for block_hash, token_ids in zip(block_hashes, expected_token_ids):
        assert block_hash.token_ids == token_ids
        assert block_hash.extra_keys is None

0 commit comments

Comments
 (0)