diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index a6c0162d3f3..d598d12571f 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -51,7 +51,7 @@ def test_prefill():
     all_token_ids = common_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert len(req0.kv_block_hashes) == 3
+    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
     assert not computed_blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
@@ -76,7 +76,7 @@ def test_prefill():
     unique_token_ids = [3] * 5
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert len(req1.kv_block_hashes) == 3
+    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
     assert [b.block_id for b in computed_blocks] == [0, 1, 2]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
@@ -107,7 +107,7 @@ def test_prefill():
     unique_token_ids = [3] * 6
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert len(req2.kv_block_hashes) == 3
+    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
     assert [b.block_id for b in computed_blocks] == [0, 1, 2]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
@@ -494,10 +494,11 @@ def test_mm_prefix_caching():
     # Completed block should have hashes with extra keys.
     assert not computed_blocks
     assert num_computed_tokens == 0
-    assert len(req0.kv_block_hashes) == 3
-    assert req0.kv_block_hashes[0].extra_keys == ("aaa", )
-    assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb")
-    assert req0.kv_block_hashes[2].extra_keys == ("bbb", )
+    block_hashes = manager.req_to_block_hashes[req0.request_id]
+    assert len(block_hashes) == 3
+    assert block_hashes[0].extra_keys == ("aaa", )
+    assert block_hashes[1].extra_keys == ("aaa", "bbb")
+    assert block_hashes[2].extra_keys == ("bbb", )
     blocks = manager.allocate_slots(req0, 59, computed_blocks)
     assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
 
@@ -510,8 +511,8 @@ def test_mm_prefix_caching():
     assert new_blocks is not None and len(new_blocks) == 0
 
     # The just completed block should have hashes with extra keys.
-    assert len(req0.kv_block_hashes) == 4
-    assert req0.kv_block_hashes[3].extra_keys == ("ccc", )
+    assert len(block_hashes) == 4
+    assert block_hashes[3].extra_keys == ("ccc", )
 
     # Cache hit.
     unique_token_ids = [-1] * 7 + [200] * 5
@@ -613,7 +614,7 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req1 = make_request("1", all_token_ids)
     computed_blocks, _ = manager.get_computed_blocks(req1)
-    assert len(req1.kv_block_hashes) == 3
+    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
     assert len(computed_blocks) == 3
     blocks = manager.allocate_slots(req1, 7, computed_blocks)
     assert [b.block_id for b in blocks] == [4]
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index de349ec1209..afb3bd1fb75 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -72,6 +72,12 @@ def __init__(
         self.req_to_blocks: DefaultDict[str,
                                         List[KVCacheBlock]] = defaultdict(list)
 
+        # Mapping from request ID to kv block hashes.
+        # This is to avoid recomputing the block hashes for each call of
+        # `get_computed_blocks` or `allocate_slots`.
+        self.req_to_block_hashes: DefaultDict[
+            str, List[BlockHashType]] = defaultdict(list)
+
     @property
     def usage(self) -> float:
         return 1.0 - (self.free_block_queue.num_free_blocks /
@@ -97,11 +103,11 @@ def get_computed_blocks(
         computed_blocks = []
 
         # The block hashes for the request may already be computed
-        # if the request was preempted and resumed.
-        if not request.kv_block_hashes:
-            request.set_kv_block_hashes(
-                hash_request_tokens(self.block_size, request))
-        block_hashes = request.kv_block_hashes
+        # if the scheduler has tried to schedule the request before.
+        block_hashes = self.req_to_block_hashes[request.request_id]
+        if not block_hashes:
+            block_hashes = hash_request_tokens(self.block_size, request)
+            self.req_to_block_hashes[request.request_id] = block_hashes
 
         for block_hash in block_hashes:
             # block_hashes is a chain of block hashes. If a block hash is not
@@ -437,7 +443,8 @@ def _cache_full_blocks(
             full_blocks: The list of blocks to update hash metadata.
             prev_block: The previous block in the chain.
         """
-        num_cached_block_hashes = len(request.kv_block_hashes)
+        block_hashes = self.req_to_block_hashes[request.request_id]
+        num_cached_block_hashes = len(block_hashes)
 
         # Update the new blocks with the block hashes through the chain.
         prev_block_hash_value = None
@@ -470,7 +477,7 @@ def _cache_full_blocks(
                 # this request (either the prompt tokens or the previously
                 # generated tokens with preemption). In this case we simply
                 # reuse the block hash.
-                block_hash = request.kv_block_hashes[blk_idx]
+                block_hash = block_hashes[blk_idx]
             else:
                 # Otherwise compute the block hash and cache it in the request
                 # in case it will be preempted in the future.
@@ -492,9 +499,17 @@ def _cache_full_blocks(
                 # Compute the hash of the current block.
                 block_hash = hash_block_tokens(prev_block_hash_value,
                                                block_tokens, extra_keys)
-                request.append_kv_block_hashes(block_hash)
+                block_hashes.append(block_hash)
 
             # Update and added the full block to the cache.
             blk.block_hash = block_hash
             self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
             prev_block_hash_value = block_hash.hash_value
+
+    def free_block_hashes(self, request: Request) -> None:
+        """Discard the block hashes for the request.
+
+        NOTE: Unlike `free`, this method should be called only when the request
+        is finished, not when it is preempted.
+        """
+        self.req_to_block_hashes.pop(request.request_id, None)
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index 35d9424f942..1aa34ee3860 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -579,6 +579,7 @@ def finish_requests(
     def _free_request(self, request: Request) -> None:
         assert request.is_finished()
         self.kv_cache_manager.free(request)
+        self.kv_cache_manager.free_block_hashes(request)
         self.encoder_cache_manager.free(request)
         self._cached_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 89b39ea615d..bb4d2c19197 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -12,7 +12,6 @@
 if TYPE_CHECKING:
     from vllm.multimodal import MultiModalKwargs
     from vllm.multimodal.inputs import PlaceholderRange
-    from vllm.v1.core.kv_cache_utils import BlockHashType
 
 
 class Request:
@@ -63,11 +62,6 @@ def __init__(
         if self.mm_hashes:
             assert len(self.mm_inputs) == len(self.mm_hashes)
 
-        # Cache the computed kv block hashes of the request to avoid
-        # recomputing.
-        self._kv_block_hashes: List[BlockHashType] = []
-        self.kv_block_hashes = ConstantList(self._kv_block_hashes)
-
         # Read-only views
         # Prevent directly appending to the these lists since
         # they should also be updated simultaneously.
@@ -124,13 +118,6 @@ def get_num_encoder_tokens(self, input_id: int) -> int:
         num_tokens = self.mm_positions[input_id]["length"]
         return num_tokens
 
-    def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
-        self._kv_block_hashes = value
-        self.kv_block_hashes = ConstantList(self._kv_block_hashes)
-
-    def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
-        self._kv_block_hashes.append(block_hash)
-
 
 class RequestStatus(enum.IntEnum):
     """Status of a request."""
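
For illustration, here is a minimal, self-contained sketch of the lifecycle this patch introduces. It is not part of the patch: `ToyKVCacheManager` and its tuple-hash stand-in for `hash_request_tokens` are hypothetical names invented for the example; only the `req_to_block_hashes` attribute and the `free` / `free_block_hashes` split mirror the change above. The point it demonstrates: block hashes are computed once per request, survive preemption, and are discarded only when the request finishes.

from collections import defaultdict
from typing import DefaultDict, List


class ToyKVCacheManager:
    """Hypothetical stand-in; only the hash-cache lifecycle mirrors the patch."""

    def __init__(self) -> None:
        # Mapping from request ID to kv block hashes, kept across
        # scheduling attempts and preemptions.
        self.req_to_block_hashes: DefaultDict[
            str, List[int]] = defaultdict(list)

    def get_computed_blocks(self, request_id: str, token_ids: List[int],
                            block_size: int = 16) -> List[int]:
        # Reuse the cached hashes if the scheduler has already tried to
        # schedule this request; otherwise compute and cache them.
        block_hashes = self.req_to_block_hashes[request_id]
        if not block_hashes:
            # Toy stand-in for hash_request_tokens: hash each full block.
            block_hashes = [
                hash(tuple(token_ids[i:i + block_size]))
                for i in range(0, len(token_ids) - block_size + 1, block_size)
            ]
            self.req_to_block_hashes[request_id] = block_hashes
        return block_hashes

    def free(self, request_id: str) -> None:
        # Called on preemption as well; the block hashes are deliberately
        # kept so a resumed request does not rehash its tokens.
        pass

    def free_block_hashes(self, request_id: str) -> None:
        # Called only on the finish path (cf. _free_request in the diff).
        self.req_to_block_hashes.pop(request_id, None)


manager = ToyKVCacheManager()
hashes = manager.get_computed_blocks("req-0", list(range(48)))
assert len(hashes) == 3               # 48 tokens / 16 tokens per block
manager.free("req-0")                 # preemption: the hash cache survives
assert manager.req_to_block_hashes["req-0"] is hashes
manager.free_block_hashes("req-0")    # finish: the cache is dropped
assert "req-0" not in manager.req_to_block_hashes

This also explains the ordering in `_free_request` above: `free(request)` is shared with the preemption path and must leave the hashes alone, so the finish-only cleanup gets its own call, `free_block_hashes(request)`.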