@@ -49,9 +49,10 @@ def test_prefill():
49
49
unique_token_ids = [3 ] * 7
50
50
all_token_ids = common_token_ids + unique_token_ids
51
51
req0 = make_request ("0" , all_token_ids )
52
- computed_blocks = manager .get_computed_blocks (req0 )
52
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req0 )
53
53
assert len (req0 .kv_block_hashes ) == 3
54
54
assert not computed_blocks
55
+ assert num_computed_tokens == 0
55
56
blocks = manager .allocate_slots (req0 , 55 , computed_blocks )
56
57
assert [b .block_id for b in blocks ] == [0 , 1 , 2 , 3 , 4 ]
57
58
@@ -73,9 +74,10 @@ def test_prefill():
73
74
# Incomplete 1 block (5 tokens)
74
75
unique_token_ids = [3 ] * 5
75
76
req1 = make_request ("1" , common_token_ids + unique_token_ids )
76
- computed_blocks = manager .get_computed_blocks (req1 )
77
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req1 )
77
78
assert len (req1 .kv_block_hashes ) == 3
78
79
assert [b .block_id for b in computed_blocks ] == [0 , 1 , 2 ]
80
+ assert num_computed_tokens == 3 * 16
79
81
num_new_tokens = 53 - 3 * 16
80
82
blocks = manager .allocate_slots (req1 , num_new_tokens , computed_blocks )
81
83
assert [b .block_id for b in blocks ] == [5 , 6 ]
@@ -91,7 +93,7 @@ def test_prefill():
91
93
# All blocks should be available.
92
94
assert manager .free_block_queue .num_free_blocks == 10
93
95
# The order should be
94
- # [unallocated (7, 8)]
96
+ # [unallocated (7, 8, 9 )]
95
97
# [unique_req0 (4, 3)]
96
98
# [unique_req1 (6, 5)]
97
99
# [common (2, 1, 0)]
@@ -103,9 +105,10 @@ def test_prefill():
103
105
# Incomplete 1 block (6 tokens)
104
106
unique_token_ids = [3 ] * 6
105
107
req2 = make_request ("2" , common_token_ids + unique_token_ids )
106
- computed_blocks = manager .get_computed_blocks (req2 )
108
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req2 )
107
109
assert len (req2 .kv_block_hashes ) == 3
108
110
assert [b .block_id for b in computed_blocks ] == [0 , 1 , 2 ]
111
+ assert num_computed_tokens == 3 * 16
109
112
num_new_tokens = 53 - 3 * 16
110
113
blocks = manager .allocate_slots (req2 , num_new_tokens , computed_blocks )
111
114
assert [b .block_id for b in blocks ] == [7 , 8 ]
@@ -123,8 +126,9 @@ def test_prefill():
123
126
124
127
# Cache miss and eviction.
125
128
req3 = make_request ("3" , [99 ] * (16 * 9 ))
126
- computed_blocks = manager .get_computed_blocks (req3 )
129
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req3 )
127
130
assert not computed_blocks
131
+ assert num_computed_tokens == 0
128
132
blocks = manager .allocate_slots (req3 , 16 * 9 , computed_blocks )
129
133
# This block ID order also checks the eviction order.
130
134
assert [b .block_id for b in blocks ] == [9 , 4 , 3 , 6 , 5 , 8 , 7 , 2 , 1 , 0 ]
@@ -150,8 +154,9 @@ def test_decode():
150
154
# Incomplete 1 block (7 tokens)
151
155
unique_token_ids = [3 ] * 7
152
156
req0 = make_request ("0" , common_token_ids + unique_token_ids )
153
- computed_blocks = manager .get_computed_blocks (req0 )
157
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req0 )
154
158
assert not computed_blocks
159
+ assert num_computed_tokens == 0
155
160
blocks = manager .allocate_slots (req0 , 55 , computed_blocks )
156
161
assert [b .block_id for b in blocks ] == [0 , 1 , 2 , 3 , 4 ]
157
162
@@ -197,16 +202,18 @@ def test_evict():
197
202
198
203
last_token_id = 5 * 16 + 7
199
204
req0 = make_request ("0" , list (range (last_token_id )))
200
- computed_blocks = manager .get_computed_blocks (req0 )
205
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req0 )
201
206
assert not computed_blocks
207
+ assert num_computed_tokens == 0
202
208
blocks = manager .allocate_slots (req0 , 5 * 16 + 7 , computed_blocks )
203
209
assert len (blocks ) == 7 # 5 full + 1 partial + 1 preallocated
204
210
205
211
# 3 blocks.
206
212
req1 = make_request ("1" , list (range (last_token_id ,
207
213
last_token_id + 3 * 16 )))
208
- computed_blocks = manager .get_computed_blocks (req1 )
214
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req1 )
209
215
assert not computed_blocks
216
+ assert num_computed_tokens == 0
210
217
blocks = manager .allocate_slots (req1 , 3 * 16 , computed_blocks )
211
218
assert len (blocks ) == 3 # 3 full blocks
212
219
last_token_id += 3 * 16
@@ -222,8 +229,9 @@ def test_evict():
222
229
223
230
# Touch the first 2 blocks.
224
231
req2 = make_request ("2" , list (range (2 * 16 + 3 )))
225
- computed_blocks = manager .get_computed_blocks (req2 )
232
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req2 )
226
233
assert [b .block_id for b in computed_blocks ] == [0 , 1 ]
234
+ assert num_computed_tokens == 2 * 16
227
235
blocks = manager .allocate_slots (req2 , 3 , computed_blocks )
228
236
assert [b .block_id for b in blocks ] == [6 , 5 ]
229
237
assert manager .free_block_queue .num_free_blocks == 6
@@ -247,8 +255,9 @@ def test_hash_block_correct_reuse():
247
255
# Allocate 1 block and cache it.
248
256
num_tokens = block_size * 1
249
257
req = make_request ("0" , list (range (num_tokens )))
250
- computed_blocks = manager .get_computed_blocks (req )
258
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req )
251
259
assert not computed_blocks
260
+ assert num_computed_tokens == 0
252
261
blocks = manager .allocate_slots (req , num_tokens , computed_blocks )
253
262
assert len (blocks ) == 1
254
263
@@ -258,8 +267,9 @@ def test_hash_block_correct_reuse():
258
267
# Allocate a new block that's not full, make sure hash info on the
259
268
# block is cleared.
260
269
req = make_request ("1" , list (range (num_tokens - 1 )))
261
- computed_blocks = manager .get_computed_blocks (req )
270
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req )
262
271
assert not computed_blocks
272
+ assert num_computed_tokens == 0
263
273
blocks = manager .allocate_slots (req , num_tokens - 1 , computed_blocks )
264
274
assert len (blocks ) == 1
265
275
@@ -284,16 +294,18 @@ def test_computed_blocks_not_evicted():
284
294
# Allocate a block and cache it.
285
295
num_tokens = block_size * 1
286
296
req0 = make_request ("0" , list (range (num_tokens )))
287
- computed_blocks = manager .get_computed_blocks (req0 )
297
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req0 )
288
298
assert not computed_blocks
299
+ assert num_computed_tokens == 0
289
300
blocks = manager .allocate_slots (req0 , num_tokens , computed_blocks )
290
301
assert len (blocks ) == 1
291
302
assert blocks [0 ].block_id == 0
292
303
293
304
# Allocate another block.
294
305
req1 = make_request ("1" , list (range (num_tokens , num_tokens * 2 )))
295
- computed_blocks = manager .get_computed_blocks (req1 )
306
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req1 )
296
307
assert not computed_blocks
308
+ assert num_computed_tokens == 0
297
309
blocks = manager .allocate_slots (req1 , num_tokens , computed_blocks )
298
310
assert len (blocks ) == 1
299
311
assert blocks [0 ].block_id == 1
@@ -305,9 +317,10 @@ def test_computed_blocks_not_evicted():
305
317
# Now if we have a cache hit on the first block, we should evict the second
306
318
# cached block rather than the first one.
307
319
req2 = make_request ("2" , list (range (num_tokens * 2 )))
308
- computed_blocks = manager .get_computed_blocks (req2 )
320
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req2 )
309
321
assert len (computed_blocks ) == 1
310
322
assert computed_blocks [0 ].block_id == 0
323
+ assert num_computed_tokens == block_size
311
324
312
325
blocks = manager .allocate_slots (req2 , num_tokens * 2 - num_tokens ,
313
326
computed_blocks )
@@ -331,8 +344,9 @@ def test_basic_prefix_caching_disabled():
331
344
332
345
req1 = make_request ("1" , list (range (10 ))) # 2 blocks and some more
333
346
334
- computed_blocks = manager .get_computed_blocks (req1 )
347
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req1 )
335
348
assert not computed_blocks
349
+ assert num_computed_tokens == 0
336
350
blocks = manager .allocate_slots (req1 , 10 , computed_blocks )
337
351
assert len (blocks ) == 3
338
352
@@ -341,15 +355,17 @@ def test_basic_prefix_caching_disabled():
341
355
342
356
# No caching.
343
357
req2 = make_request ("2" , list (range (16 ))) # shared prefix
344
- computed_blocks = manager .get_computed_blocks (req2 )
358
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req2 )
345
359
assert not computed_blocks
360
+ assert num_computed_tokens == 0
346
361
blocks = manager .allocate_slots (req2 , 16 , computed_blocks )
347
362
assert len (blocks ) == 4
348
363
349
364
# New requests should not have any blocks.
350
365
req3 = make_request ("3" , list (range (4 )))
351
- computed_blocks = manager .get_computed_blocks (req3 )
366
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req3 )
352
367
assert not computed_blocks
368
+ assert num_computed_tokens == 0
353
369
blocks = manager .allocate_slots (req3 , 4 , computed_blocks )
354
370
assert not blocks
355
371
@@ -371,8 +387,9 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
371
387
num_preallocated_blocks = cdiv (num_preallocate_tokens , block_size )
372
388
373
389
req = make_request ("0" , list (range (block_size * 30 )))
374
- computed_blocks = manager .get_computed_blocks (req )
390
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req )
375
391
assert not computed_blocks
392
+ assert num_computed_tokens == 0
376
393
# Just ask for 1 block.
377
394
blocks = manager .allocate_slots (req , block_size , computed_blocks )
378
395
req .num_computed_tokens = block_size
@@ -469,10 +486,11 @@ def test_mm_prefix_caching():
469
486
all_token_ids ,
470
487
mm_positions = mm_positions ,
471
488
mm_hashes = mm_hashes )
472
- computed_blocks = manager .get_computed_blocks (req0 )
489
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req0 )
473
490
474
491
# Completed block should have hashes with extra keys.
475
492
assert not computed_blocks
493
+ assert num_computed_tokens == 0
476
494
assert len (req0 .kv_block_hashes ) == 3
477
495
assert req0 .kv_block_hashes [0 ].extra_keys == ("aaa" , )
478
496
assert req0 .kv_block_hashes [1 ].extra_keys == ("aaa" , "bbb" )
@@ -503,8 +521,9 @@ def test_mm_prefix_caching():
503
521
all_token_ids ,
504
522
mm_positions = mm_positions ,
505
523
mm_hashes = mm_hashes )
506
- computed_blocks = manager .get_computed_blocks (req1 )
524
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req1 )
507
525
assert len (computed_blocks ) == 3
526
+ assert num_computed_tokens == 3 * 16
508
527
509
528
510
529
def test_prefill_not_enough_free_blocks_with_computed_blocks ():
@@ -527,15 +546,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
527
546
# | Common-0 | Common-1 | Common-2 | ... |
528
547
common_token_ids = [i for i in range (3 ) for _ in range (16 )]
529
548
req0 = make_request ("0" , common_token_ids )
530
- computed_blocks = manager .get_computed_blocks (req0 )
549
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req0 )
531
550
assert not computed_blocks
551
+ assert num_computed_tokens == 0
532
552
manager .allocate_slots (req0 , 48 , computed_blocks )
533
553
block_part0 = manager .req_to_blocks [req0 .request_id ]
534
554
535
555
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
536
556
req1 = make_request ("1" , common_token_ids * 2 )
537
- computed_blocks = manager .get_computed_blocks (req1 )
557
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req1 )
538
558
assert computed_blocks == block_part0
559
+ assert num_computed_tokens == 3 * 16
539
560
manager .allocate_slots (req1 , 48 , computed_blocks )
540
561
block_part1 = manager .req_to_blocks [req1 .request_id ]
541
562
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
@@ -547,17 +568,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
547
568
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
548
569
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
549
570
req2 = make_request ("2" , [7 ] * block_size * 2 )
550
- computed_blocks = manager .get_computed_blocks (req2 )
571
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req2 )
551
572
assert not computed_blocks
573
+ assert num_computed_tokens == 0
552
574
manager .allocate_slots (req2 , block_size * 2 , computed_blocks )
553
575
554
576
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
555
577
# but it cannot be allocated due to insufficient free blocks (2).
556
578
# In this case, the ref_cnt of the computed blocks should not be changed.
557
579
assert manager .free_block_queue .num_free_blocks == 5
558
580
req3 = make_request ("3" , common_token_ids * 3 )
559
- computed_blocks = manager .get_computed_blocks (req3 )
581
+ computed_blocks , num_computed_tokens = manager .get_computed_blocks (req3 )
560
582
assert computed_blocks == block_part1
583
+ assert num_computed_tokens == 6 * 16
561
584
# Req3 cannot be allocated.
562
585
assert manager .allocate_slots (req3 , 48 , computed_blocks ) is None
563
586
# Block 0-2 are used by Req 1.
0 commit comments