@@ -98,6 +98,13 @@ def __init__(
         self.req_to_blocks: DefaultDict[str, ReqKVCacheBlocks] = defaultdict(
             lambda: [[] for _ in range(self.num_kv_cache_groups)])
 
+        # Mapping from request ID to kv block hashes.
+        # This is to avoid recomputing the block hashes for each call of
+        # `get_computed_blocks` or `allocate_slots`.
+        self.req_to_block_hashes: DefaultDict[
+            str, List[List[BlockHashType]]] = defaultdict(
+                lambda: [[] for _ in range(self.num_kv_cache_groups)])
+
     @property
     def usage(self) -> float:
         return 1.0 - (self.free_block_queue.num_free_blocks /
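For context on the first hunk: `req_to_block_hashes` is a `defaultdict` keyed by request ID whose factory builds one empty hash list per KV cache group, so the first lookup for a new request yields fresh per-group lists without any explicit initialization. A minimal standalone sketch of that pattern (the `BlockHash` tuple and the group count below are illustrative stand-ins, not vLLM's actual `BlockHashType`):

```python
from collections import defaultdict
from typing import DefaultDict, List, NamedTuple, Tuple


class BlockHash(NamedTuple):
    """Illustrative stand-in for vLLM's BlockHashType."""
    hash_value: int
    token_ids: Tuple[int, ...]


num_kv_cache_groups = 2  # e.g. a full-attention group and a sliding-window group

# One list of block hashes per KV cache group, created lazily per request.
req_to_block_hashes: DefaultDict[str, List[List[BlockHash]]] = defaultdict(
    lambda: [[] for _ in range(num_kv_cache_groups)])

# First access for an unseen request ID returns fresh per-group lists.
hashes = req_to_block_hashes["req-0"]
assert hashes == [[], []]

# Hashes appended for group 0 persist across later lookups of the same request.
hashes[0].append(BlockHash(hash_value=0xABCD, token_ids=(1, 2, 3, 4)))
assert len(req_to_block_hashes["req-0"][0]) == 1
```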
@@ -121,17 +128,19 @@ def get_computed_blocks(self,
             return [[] for _ in self.managers], 0
 
         # The block hashes for the request may already be computed
-        # if the request was preempted and resumed.
-        if not request.kv_block_hashes:
-            request.set_kv_block_hashes([
+        # if the scheduler has tried to schedule the request before.
+        block_hashes = self.req_to_block_hashes[request.request_id]
+        if not any(block_hashes):
+            block_hashes = [
                 hash_request_tokens(manager.block_size, request, i)
                 for i, manager in enumerate(self.managers)
-            ])
+            ]
+            self.req_to_block_hashes[request.request_id] = block_hashes
 
         computed_blocks: ReqKVCacheBlocks = []  # computed blocks of each group
         prefix_length: List[PrefixLength] = [
         ]  # possible cached prefix length of each group
-        block_hashes = request.kv_block_hashes
+
         for i, manager in enumerate(self.managers):
             prefix_length_i, computed_blocks_i = (
                 manager.get_possible_cached_prefix(block_hashes[i]))
@@ -154,7 +163,6 @@ def get_computed_blocks(self,
         for i, manager in enumerate(self.managers):
             computed_blocks[i] = computed_blocks[i][:num_computed_tokens //
                                                     manager.block_size]
-
         return computed_blocks, num_computed_tokens
 
     def allocate_slots(
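The `get_computed_blocks` hunks above switch to a compute-once pattern: look the hashes up in the manager-owned cache and only hash the request's tokens when nothing is cached yet, so repeated scheduling attempts reuse the first computation. A simplified sketch of that pattern, assuming a toy `hash_full_blocks` helper in place of vLLM's `hash_request_tokens` and ignoring per-group block sizes:

```python
import hashlib
from collections import defaultdict
from typing import DefaultDict, List

BLOCK_SIZE = 16   # illustrative block size
NUM_GROUPS = 2    # illustrative number of KV cache groups

# Same shape as the manager's cache: request id -> per-group block hashes.
req_to_block_hashes: DefaultDict[str, List[List[str]]] = defaultdict(
    lambda: [[] for _ in range(NUM_GROUPS)])


def hash_full_blocks(token_ids: List[int], block_size: int) -> List[str]:
    """Toy stand-in for hash_request_tokens: hash each full block of tokens."""
    return [
        hashlib.sha256(bytes(token_ids[start:start + block_size])).hexdigest()
        for start in range(0, len(token_ids) - block_size + 1, block_size)
    ]


def get_block_hashes(request_id: str, token_ids: List[int]) -> List[List[str]]:
    """Compute the hashes once per request; later calls reuse the cached ones."""
    block_hashes = req_to_block_hashes[request_id]
    if not any(block_hashes):  # nothing cached for this request yet
        block_hashes = [
            hash_full_blocks(token_ids, BLOCK_SIZE) for _ in range(NUM_GROUPS)
        ]
        req_to_block_hashes[request_id] = block_hashes
    return block_hashes


tokens = list(range(40))            # two full blocks plus 8 leftover tokens
first = get_block_hashes("req-0", tokens)
second = get_block_hashes("req-0", tokens)
assert second is first              # the second call hit the cache
```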
@@ -560,8 +568,8 @@ def _cache_full_blocks(
             prev_block: The previous block in the chain.
             kv_cache_group_id: The KV cache group that the blocks belong to
         """
-        num_cached_block_hashes = len(
-            request.kv_block_hashes[kv_cache_group_id])
+        block_hashes = self.req_to_block_hashes[request.request_id]
+        num_cached_block_hashes = len(block_hashes[kv_cache_group_id])
 
         # Update the new blocks with the block hashes through the chain.
         prev_block_hash_value = None
@@ -596,8 +604,7 @@ def _cache_full_blocks(
                 # this request (either the prompt tokens or the previously
                 # generated tokens with preemption). In this case we simply
                 # reuse the block hash.
-                block_hash = request.kv_block_hashes[kv_cache_group_id][
-                    blk_idx]
+                block_hash = block_hashes[kv_cache_group_id][blk_idx]
             else:
                 # Otherwise compute the block hash and cache it in the request
                 # in case it will be preempted in the future.
@@ -620,13 +627,21 @@ def _cache_full_blocks(
                 block_hash = hash_block_tokens(prev_block_hash_value,
                                                block_tokens, kv_cache_group_id,
                                                extra_keys)
-                request.append_kv_block_hashes(kv_cache_group_id, block_hash)
+                block_hashes[kv_cache_group_id].append(block_hash)
 
             # Update and added the full block to the cache.
             blk.block_hash = block_hash
             self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
             prev_block_hash_value = block_hash.hash_value
 
+    def free_block_hashes(self, request: Request) -> None:
+        """Discard the block hashes for the request.
+
+        NOTE: Unlike `free`, this method should be called only when the request
+        is finished, not when it is preempted.
+        """
+        self.req_to_block_hashes.pop(request.request_id, None)
+
     def get_null_block(self) -> KVCacheBlock:
         return self._null_block
 
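The `_cache_full_blocks` hunks rely on chained block hashes: each full block's hash folds in the previous block's hash value (`prev_block_hash_value`), so a block of tokens only matches a cached block when its entire prefix matches as well. A toy illustration of that chaining, using Python's built-in `hash` instead of vLLM's actual `hash_block_tokens` and omitting `extra_keys`:

```python
from typing import List, NamedTuple, Optional, Tuple


class ToyBlockHash(NamedTuple):
    """Toy stand-in for vLLM's BlockHashType."""
    hash_value: int
    token_ids: Tuple[int, ...]
    group_id: int


def toy_hash_block_tokens(prev_hash_value: Optional[int],
                          block_tokens: Tuple[int, ...],
                          group_id: int) -> ToyBlockHash:
    # Chain the previous block's hash value into this block's hash.
    return ToyBlockHash(hash((prev_hash_value, block_tokens, group_id)),
                        block_tokens, group_id)


def hash_full_blocks(token_ids: List[int], block_size: int,
                     group_id: int) -> List[ToyBlockHash]:
    """Hash every full block, threading each hash value into the next block."""
    hashes: List[ToyBlockHash] = []
    prev_hash_value: Optional[int] = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        block_hash = toy_hash_block_tokens(prev_hash_value, block, group_id)
        hashes.append(block_hash)
        prev_hash_value = block_hash.hash_value
    return hashes


# The same four tokens hash differently after different prefixes, which is
# what makes the chained hash safe to use as a prefix-cache key.
a = hash_full_blocks([1, 2, 3, 4, 9, 9, 9, 9], block_size=4, group_id=0)
b = hash_full_blocks([5, 6, 7, 8, 9, 9, 9, 9], block_size=4, group_id=0)
assert a[1].hash_value != b[1].hash_value
```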
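The new `free_block_hashes` hook separates the lifetime of the cached hashes from the lifetime of the blocks: preemption returns blocks to the free pool but keeps the hashes so a resumed request skips rehashing, while finishing a request releases both. A hypothetical scheduler-side call pattern illustrating that split (the function names below are assumptions for illustration; only `free` and `free_block_hashes` come from the manager itself):

```python
def preempt_request(manager, request):
    # On preemption only the blocks go back to the free pool; the entry in
    # req_to_block_hashes stays cached so a later resume can reuse it.
    manager.free(request)


def finish_request(manager, request):
    # On completion (or abort) both the blocks and the per-request hash cache
    # are released, otherwise req_to_block_hashes would grow without bound.
    manager.free(request)
    manager.free_block_hashes(request)
```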