@@ -285,6 +285,29 @@ def free(self, request: Request) -> None:
285
285
if block .ref_cnt == 0 :
286
286
self .free_block_queue .append (block )
287
287
288
+ def uncache_blocks (self , request : Request ) -> int :
289
+ """Uncache the blocks that are no longer full based on the
290
+ num_computed_tokens in the given request. This happens when
291
+ the blocks were full and cached due to speculative tokens, but the
292
+ speculative tokens are not accepted.
293
+
294
+ Args:
295
+ request: The request.
296
+
297
+ Returns:
298
+ The number of uncached blocks.
299
+ """
300
+ blocks = self .req_to_blocks [request .request_id ]
301
+ num_computed_tokens = request .num_computed_tokens
302
+ num_full_blocks = num_computed_tokens // self .block_size
303
+ num_uncached_blocks = 0
304
+ for block in blocks [num_full_blocks :]:
305
+ # If the block is not cached, the following blocks are not cached.
306
+ if not self ._maybe_evict_cached_block (block ):
307
+ break
308
+ num_uncached_blocks += 1
309
+ return num_uncached_blocks
310
+
288
311
def reset_prefix_cache (self ) -> bool :
289
312
"""Reset prefix cache. This function may be used in RLHF
290
313
flows to invalid prefix caching after the weights are updated,
@@ -386,21 +409,24 @@ def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
386
409
387
410
# If the block is cached, evict it.
388
411
if self .enable_caching :
389
- self ._evict_cached_block (curr_block )
412
+ self ._maybe_evict_cached_block (curr_block )
390
413
391
414
curr_block .incr_ref ()
392
415
ret .append (curr_block )
393
416
idx += 1
394
417
395
418
return ret
396
419
397
- def _evict_cached_block (self , block : KVCacheBlock ) -> None :
420
+ def _maybe_evict_cached_block (self , block : KVCacheBlock ) -> bool :
398
421
"""
399
422
If a block is cached in `cached_block_hash_to_block`, we reset its hash
400
423
metadata and evict it from the cache.
401
424
402
425
Args:
403
426
block: The block to evict.
427
+
428
+ Returns:
429
+ True if the block is evicted, False otherwise.
404
430
"""
405
431
block_hash = block .block_hash
406
432
if block_hash and block_hash in self .cached_block_hash_to_block :
@@ -410,6 +436,9 @@ def _evict_cached_block(self, block: KVCacheBlock) -> None:
410
436
if len (self .cached_block_hash_to_block [block_hash ]) == 0 :
411
437
del self .cached_block_hash_to_block [block_hash ]
412
438
439
+ return True
440
+ return False
441
+
413
442
def _get_cached_block (self ,
414
443
block_hash : BlockHashType ) -> Optional [KVCacheBlock ]:
415
444
"""Get a cached block by the block hash, or None if cache miss.
0 commit comments