
Commit a4f49d5

Merge remote-tracking branch 'origin/main' into coalesce-stream
# Conflicts:
#   vllm/v1/engine/async_llm.py

2 parents c20997a + aea9436 · commit a4f49d5

35 files changed: +424 -68 lines

tests/core/block/test_prefix_caching_block.py

Lines changed: 38 additions & 0 deletions
@@ -796,6 +796,44 @@ def test_find_cached_blocks_prefix():
             block_hashes=block_hashes_seq1)
         assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
 
+    # Test reset prefix cache
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [10])
+    @pytest.mark.parametrize("block_size", [16])
+    def test_reset_prefix_cache(num_blocks: int, block_size: int):
+        """This test case simulates the case of resetting the prefix cache."""
+
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        token_ids = list(range(3 * block_size))
+
+        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        # Free each block in the first chain.
+        for block in first_chain:
+            allocator.free(block)
+
+        # Failed to reset prefix cache because some blocks are not freed yet.
+        assert not allocator.reset_prefix_cache()
+        assert allocator.get_prefix_cache_hit_rate() > 0.0
+
+        # Free each block in the second chain.
+        for block in second_chain:
+            allocator.free(block)
+
+        # Reset prefix cache.
+        assert allocator.reset_prefix_cache()
+        assert allocator.get_prefix_cache_hit_rate() == 0.0
+
     @staticmethod
     def create_immutable_chain(
         block_size: int,
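
To exercise only this new allocator-level test locally, a small driver like the following should work; the file path and test name are taken from the hunk above, and pytest.main is standard pytest API, so treat this as a convenience sketch rather than part of the commit.

# Hedged helper: run only the new reset test from this diff.
# Equivalent to: pytest tests/core/block/test_prefix_caching_block.py -k test_reset_prefix_cache
import sys

import pytest

if __name__ == "__main__":
    sys.exit(
        pytest.main([
            "tests/core/block/test_prefix_caching_block.py",
            "-k", "test_reset_prefix_cache",
            "-q",
        ]))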

tests/v1/core/test_prefix_caching.py

Lines changed: 39 additions & 0 deletions
@@ -587,3 +587,42 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     assert {block.ref_cnt for block in block_part1[:3]} == {1}
     # Block 3-5 are free.
     assert {block.ref_cnt for block in block_part1[3:]} == {0}
+
+
+def test_reset_prefix_cache():
+    manager = KVCacheManager(
+        block_size=16,
+        num_gpu_blocks=10,
+        max_model_len=8192,
+        sliding_window=None,
+        enable_caching=True,
+        num_preallocate_tokens=0,
+    )
+
+    full_block_token_ids = [i for i in range(3) for _ in range(16)]
+    unique_token_ids = [3] * 7
+    all_token_ids = full_block_token_ids + unique_token_ids
+    req0 = make_request("0", all_token_ids)
+    blocks = manager.allocate_slots(req0, 55, [])
+    assert [b.block_id for b in blocks] == [0, 1, 2, 3]
+
+    unique_token_ids = [4] * 7
+    all_token_ids = full_block_token_ids + unique_token_ids
+    req1 = make_request("1", all_token_ids)
+    computed_blocks, _ = manager.get_computed_blocks(req1)
+    assert len(req1.kv_block_hashes) == 3
+    assert len(computed_blocks) == 3
+    blocks = manager.allocate_slots(req1, 7, computed_blocks)
+    assert [b.block_id for b in blocks] == [4]
+
+    # Failed to reset prefix cache because some blocks are not freed yet.
+    assert not manager.reset_prefix_cache()
+    assert manager.cached_block_hash_to_block
+
+    # Free the blocks.
+    manager.free(req0)
+    manager.free(req1)
+
+    assert manager.reset_prefix_cache()
+    assert not manager.cached_block_hash_to_block
+    assert all([blk.block_hash is None for blk in manager.block_pool])

vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -1293,7 +1293,7 @@ def __post_init__(self) -> None:
             raise ValueError(f"worker-use-ray can't be used with "
                              f"distributed executor backend "
                              f"'{self.distributed_executor_backend}'.")
-        ray_only_devices = ["tpu", "hpu"]
+        ray_only_devices = ["tpu"]
         from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
                 and self.world_size > 1):

vllm/core/block/cpu_gpu_block_allocator.py

Lines changed: 7 additions & 0 deletions
@@ -339,6 +339,13 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         assert device in self._allocators
         return self._allocators[device].get_prefix_cache_hit_rate()
 
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all devices."""
+        success = True
+        for allocator in self._allocators.values():
+            success = success and allocator.reset_prefix_cache()
+        return success
+
     def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
         """Returns and clears the mapping of source to destination block IDs.
         Will be called after every swapping operations for now, and after every
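
One subtlety in the aggregation added above: because Python's `and` short-circuits, once one allocator returns False the remaining allocators' reset_prefix_cache() calls are skipped for that pass. A small standalone illustration with hypothetical stand-ins (not vLLM code):

# Stand-in allocator resets, to show the short-circuit in "success and reset()".
calls = []

def make_reset(name, result):
    def reset():
        calls.append(name)
        return result
    return reset

allocators = [make_reset("cpu", False), make_reset("gpu", True)]

success = True
for reset in allocators:
    success = success and reset()   # same pattern as the diff above

assert success is False
assert calls == ["cpu"]  # the "gpu" reset is never invoked after the first failure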

vllm/core/block/interfaces.py

Lines changed: 10 additions & 0 deletions
@@ -192,6 +192,11 @@ def get_prefix_cache_hit_rate(self) -> float:
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass
 
+    @abstractmethod
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache."""
+        pass
+
     class NoFreeBlocksError(ValueError):
         pass
 

@@ -297,6 +302,11 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass
 
+    @abstractmethod
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache."""
+        pass
+
     @abstractmethod
     def find_cached_blocks_prefix(
         self,

vllm/core/block/naive_block.py

Lines changed: 14 additions & 5 deletions
@@ -1,5 +1,5 @@
 from collections import deque
-from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
+from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
 
 from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
                                     get_all_blocks_recursively)

@@ -136,16 +136,18 @@ def _allocate_block_id(self) -> BlockId:
         self._refcounter.incr(block_id)
         return block_id
 
-    def _free_block_id(self, block: Block) -> None:
-        block_id = block.block_id
+    def _free_block_id(self, block: Union[Block, BlockId]) -> None:
+        if isinstance(block, Block):
+            block_id = block.block_id
+            block.block_id = None
+        else:
+            block_id = block
         assert block_id is not None
 
         refcount = self._refcounter.decr(block_id)
         if refcount == 0:
             self._free_block_indices.appendleft(block_id)
 
-        block.block_id = None
-
     def free(self, block: Block, keep_block_object: bool = False) -> None:
         # Release the physical block id
         self._free_block_id(block)

@@ -154,6 +156,9 @@ def free(self, block: Block, keep_block_object: bool = False) -> None:
         if not keep_block_object:
             self._block_pool.free_block(block)
 
+    def free_block_id(self, block_id: BlockId) -> None:
+        self._free_block_id(block_id)
+
     def fork(self, last_block: Block) -> List[Block]:
         """Creates a new sequence of blocks that shares the same underlying
         memory as the original sequence.

@@ -325,6 +330,10 @@ def swap_in(self, blocks: List[Block]) -> None:
     def get_prefix_cache_hit_rate(self) -> float:
         return -1
 
+    def reset_prefix_cache(self) -> bool:
+        """No prefix cache for naive block allocator."""
+        return True
+
     def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
         # Not applicable for naive block allocator.
         return []

vllm/core/block/prefix_caching_block.py

Lines changed: 43 additions & 1 deletion
@@ -12,6 +12,7 @@
 from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
                                          NaiveBlockAllocator)
 from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
+from vllm.logger import init_logger
 from vllm.sequence import Sequence
 
 PrefixHash = int

@@ -21,6 +22,8 @@
 # then we know this block hasn't been accessed yet.
 _DEFAULT_LAST_ACCESSED_TIME = -1
 
+logger = init_logger(__name__)
+
 
 class BlockTracker:
     """Used to track the status of a block inside the prefix caching allocator

@@ -105,7 +108,8 @@ def __init__(
 
         # Evitor used to maintain how we want to handle those computed blocks
         # if we find memory pressure is high.
-        self.evictor: Evictor = make_evictor(eviction_policy)
+        self.eviction_policy = eviction_policy
+        self.evictor: Evictor = make_evictor(self.eviction_policy)
 
         # We share the refcounter between allocators. This allows us to promote
         # blocks originally allocated in the hashless allocator to immutable

@@ -428,6 +432,44 @@ def all_block_ids(self) -> FrozenSet[int]:
     def get_prefix_cache_hit_rate(self) -> float:
         return self.metric_data.get_hit_rate()
 
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache. This function may be used in RLHF
+        flows to invalid prefix caching after the weights are updated,
+        or used for resetting prefix caching status for benchmarking.
+
+        Returns:
+            bool: True if the prefix cache is successfully reset,
+            False otherwise.
+        """
+        num_used_blocks = (self.get_num_total_blocks() -
+                           self.get_num_free_blocks())
+        if num_used_blocks > 0:
+            logger.warning(
+                "Failed to reset prefix cache because some "
+                "blocks (%d) are not freed yet", num_used_blocks)
+            return False
+
+        # Free all blocks in the evictor.
+        while (block_id :=
+               self._maybe_allocate_evicted_block_id()) is not None:
+            self._hashless_allocator.free_block_id(block_id)
+
+        # Should not have any cached blocks because all blocks are evicted.
+        assert not self._cached_blocks
+
+        # Reset the evictor.
+        self.evictor = make_evictor(self.eviction_policy)
+
+        # Reset the block tracker.
+        for block_id in self._block_tracker:
+            self._block_tracker[block_id] = BlockTracker()
+
+        # Reset the metrics.
+        self.metric_data = CacheMetricData()
+
+        logger.info("Successfully reset prefix cache")
+        return True
+
     def is_block_cached(self, block: Block) -> bool:
         assert block.content_hash is not None
         return block.content_hash in self._cached_blocks
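
The key precondition above is that every block must already be back in the free pool (or the evictor) before the reset runs. A minimal sketch of the happy path on a freshly built allocator, using only constructors and accessors that appear in this commit's hunks; it is illustrative, not a replacement for the unit test earlier in the diff:

# Minimal sketch, assuming only APIs shown in this commit.
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator

allocator = PrefixCachingBlockAllocator(num_blocks=10, block_size=16)

# Nothing is handed out yet, so num_used_blocks == 0 and the reset succeeds,
# rebuilding the evictor/tracker state and zeroing the hit-rate metrics.
assert allocator.get_num_total_blocks() == allocator.get_num_free_blocks()
assert allocator.reset_prefix_cache()
assert allocator.get_prefix_cache_hit_rate() == 0.0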

vllm/core/block_manager.py

Lines changed: 3 additions & 0 deletions
@@ -455,6 +455,9 @@ def get_num_free_cpu_blocks(self) -> int:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_allocator.get_prefix_cache_hit_rate(device)
 
+    def reset_prefix_cache(self) -> bool:
+        return self.block_allocator.reset_prefix_cache()
+
     def _can_swap(self,
                   seq_group: SequenceGroup,
                   device: Device,

vllm/core/interfaces.py

Lines changed: 5 additions & 0 deletions
@@ -122,6 +122,11 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass
 
+    @abstractmethod
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all devices."""
+        pass
+
     @abstractmethod
     def get_num_cached_tokens(self, seq: Sequence) -> int:
         pass

vllm/core/placeholder_block_space_manager.py

Lines changed: 3 additions & 0 deletions
@@ -90,5 +90,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup,
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return -1
 
+    def reset_prefix_cache(self) -> bool:
+        return True
+
     def get_num_cached_tokens(self, seq: Sequence) -> int:
         return 0

vllm/core/scheduler.py

Lines changed: 3 additions & 0 deletions
@@ -504,6 +504,9 @@ def has_unfinished_seqs(self) -> bool:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)
 
+    def reset_prefix_cache(self) -> bool:
+        return self.block_manager.reset_prefix_cache()
+
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion
@@ -397,7 +397,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'or equal to the number of GPUs available, "mp" will be used to '
             'keep processing on a single host. Otherwise, this will default '
             'to "ray" if Ray is installed and fail otherwise. Note that tpu '
-            'and hpu only support Ray for distributed inference.')
+            'only supports Ray for distributed inference.')
 
         parser.add_argument(
             '--worker-use-ray',

vllm/engine/async_llm_engine.py

Lines changed: 3 additions & 0 deletions
@@ -1182,6 +1182,9 @@ async def start_profile(self) -> None:
     async def stop_profile(self) -> None:
         self.engine.stop_profile()
 
+    async def reset_prefix_cache(self) -> None:
+        self.engine.reset_prefix_cache()
+
     async def add_lora(self, lora_request: LoRARequest) -> None:
         self.engine.add_lora(lora_request)
 

vllm/engine/llm_engine.py

Lines changed: 8 additions & 0 deletions
@@ -914,6 +914,14 @@ def has_unfinished_requests_for_virtual_engine(
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all devices."""
+
+        success = True
+        for scheduler in self.scheduler:
+            success = success and scheduler.reset_prefix_cache()
+        return success
+
     @staticmethod
     def _process_sequence_group_outputs(
         seq_group: SequenceGroup,
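
For context on how this engine-level entry point is meant to be used (the RLHF / benchmarking motivation from the allocator docstring earlier in this commit), here is a hedged offline sketch. The llm_engine attribute is how the offline LLM wrapper currently exposes its engine, and the model name is only an example; treat both as assumptions for illustration, not a documented public API.

# Hedged sketch: clear the prefix cache between phases of an offline run,
# e.g. after swapping in updated weights.
from vllm import LLM

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
llm.generate(["The capital of France is"])

# ... update / reload model weights here (RLHF-style) ...

ok = llm.llm_engine.reset_prefix_cache()
print("prefix cache reset:", ok)  # False (with a warning) if blocks are still in use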

vllm/engine/multiprocessing/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -121,6 +121,10 @@ class RPCUProfileRequest(Enum):
     STOP_PROFILE = 2
 
 
+class RPCResetPrefixCacheRequest(Enum):
+    RESET_PREFIX_CACHE = 1
+
+
 @dataclass
 class RPCLoadAdapterRequest:
     lora_request: LoRARequest

@@ -134,7 +138,8 @@ class RPCAdapterLoadedResponse:
 
 
 RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
-                      RPCUProfileRequest, RPCLoadAdapterRequest]
+                      RPCUProfileRequest, RPCLoadAdapterRequest,
+                      RPCResetPrefixCacheRequest]
 
 REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
                           RPCError]

vllm/engine/multiprocessing/client.py

Lines changed: 10 additions & 2 deletions
@@ -27,8 +27,9 @@
                       VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                       RPCAdapterLoadedResponse, RPCError,
                       RPCLoadAdapterRequest,
-                      RPCProcessRequest, RPCStartupRequest,
-                      RPCStartupResponse,
+                      RPCProcessRequest,
+                      RPCResetPrefixCacheRequest,
+                      RPCStartupRequest, RPCStartupResponse,
                       RPCUProfileRequest)
 from vllm.engine.protocol import EngineClient
 # yapf: enable

@@ -675,6 +676,13 @@ async def stop_profile(self) -> None:
         await self._send_one_way_rpc_request(
             request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
 
+    async def reset_prefix_cache(self) -> None:
+        """Reset the prefix cache"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
+            socket=self.input_socket)
+
     async def add_lora(self, lora_request: LoRARequest) -> None:
         """Load a new LoRA adapter into the engine for future requests."""
         # Uses the same I/O as generate requests
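
The client side mirrors the existing profile requests: it is a one-way RPC, so the coroutine returns once the request has been sent rather than waiting for the engine to confirm. The matching handler runs in the MQLLMEngine process and is presumably part of this commit's 35 files, but it is not shown in this excerpt. A short hedged usage sketch, assuming an already constructed and connected MQLLMEngineClient:

# Hedged sketch: "client" is an existing MQLLMEngineClient; how it is built
# is unchanged by this commit and omitted here.
import asyncio

async def clear_prefix_cache(client) -> None:
    # One-way RPC, same mechanism as the start/stop profile requests above.
    await client.reset_prefix_cache()

# asyncio.run(clear_prefix_cache(client))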
