Skip to content

Commit f15e9b5

Browse files
youkaichaoIsotr0py
authored andcommitted
[core] clean up executor class hierarchy between v1 and v0 (vllm-project#12171)
Signed-off-by: youkaichao <[email protected]> Signed-off-by: Isotr0py <[email protected]>
1 parent 7f3d4ae commit f15e9b5

File tree

6 files changed

+61
-798
lines changed

6 files changed

+61
-798
lines changed

vllm/executor/executor_base.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
7979
b = min([r[1] for r in results])
8080
return a, b
8181

82-
def initialize(self, num_gpu_blocks: int) -> None:
83-
"""
84-
Initialize the KV caches and begin the model execution loop of the
85-
underlying workers.
86-
For V1 compatibility.
87-
"""
88-
logger.info("# GPU blocks: %d", num_gpu_blocks)
89-
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
90-
self.collective_rpc("compile_or_warm_up_model")
91-
9282
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
9383
"""Initialize the KV cache by invoking the underlying worker.
9484
"""

vllm/v1/executor/abstract.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,92 @@
1-
from abc import ABC, abstractmethod
21
from typing import Type
32

43
from vllm.config import VllmConfig
4+
from vllm.executor.executor_base import ExecutorBase
5+
from vllm.executor.ray_distributed_executor import ( # noqa
6+
RayDistributedExecutor as RayDistributedExecutorV0)
7+
from vllm.executor.uniproc_executor import ( # noqa
8+
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
9+
from vllm.executor.uniproc_executor import ( # noqa
10+
UniProcExecutor as UniProcExecutorV0)
511
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
612
from vllm.v1.outputs import ModelRunnerOutput
713

814

9-
class Executor(ABC):
10-
"""Abstract class for executors."""
15+
class Executor(ExecutorBase):
16+
"""
17+
Abstract class for v1 executors, mainly define some methods for v1.
18+
For methods shared by v0 and v1, define them in ExecutorBase"""
1119

1220
@staticmethod
1321
def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
1422
executor_class: Type[Executor]
23+
parallel_config = vllm_config.parallel_config
1524
distributed_executor_backend = (
16-
vllm_config.parallel_config.distributed_executor_backend)
25+
parallel_config.distributed_executor_backend)
26+
if distributed_executor_backend is None:
27+
# If the user does not specify the distributed executor backend,
28+
# we will choose the backend based on the world size.
29+
if parallel_config.world_size > 1:
30+
distributed_executor_backend = "mp"
31+
else:
32+
distributed_executor_backend = "uni"
33+
1734
if distributed_executor_backend == "ray":
18-
from vllm.executor.ray_distributed_executor import ( # noqa
19-
RayDistributedExecutor)
2035
executor_class = RayDistributedExecutor
2136
elif distributed_executor_backend == "mp":
2237
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
2338
executor_class = MultiprocExecutor
39+
elif distributed_executor_backend == "uni":
40+
executor_class = UniProcExecutor
41+
elif distributed_executor_backend == "external_launcher":
42+
# TODO: make v1 scheduling deterministic
43+
# to support external launcher
44+
executor_class = ExecutorWithExternalLauncher
2445
else:
25-
assert (distributed_executor_backend is None)
26-
from vllm.v1.executor.uniproc_executor import UniprocExecutor
27-
executor_class = UniprocExecutor
46+
raise ValueError("Unknown distributed executor backend: "
47+
f"{distributed_executor_backend}")
2848
return executor_class
2949

30-
@abstractmethod
31-
def __init__(self, vllm_config: VllmConfig) -> None:
32-
raise NotImplementedError
33-
34-
@abstractmethod
3550
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
36-
raise NotImplementedError
51+
"""
52+
Initialize the KV caches and begin the model execution loop of the
53+
underlying workers.
54+
"""
55+
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
56+
self.collective_rpc("compile_or_warm_up_model")
3757

38-
@abstractmethod
3958
def determine_available_memory(self) -> int: # in bytes
40-
raise NotImplementedError
59+
output = self.collective_rpc("determine_available_memory")
60+
# Since we use a shared centralized controller, we take the minimum
61+
# memory size across all workers to make sure all the memory
62+
# operators can be applied to all workers.
63+
return min(output)
4164

42-
@abstractmethod
4365
def get_kv_cache_spec(self) -> KVCacheSpec:
44-
raise NotImplementedError
66+
output = self.collective_rpc("get_kv_cache_spec")
67+
for x in output:
68+
assert x == output[0]
69+
return output[0]
4570

46-
@abstractmethod
4771
def execute_model(
4872
self,
4973
scheduler_output,
5074
) -> ModelRunnerOutput:
51-
raise NotImplementedError
75+
output = self.collective_rpc("execute_model",
76+
args=(scheduler_output, ))
77+
return output[0]
5278

53-
@abstractmethod
5479
def profile(self, is_start: bool = True):
55-
raise NotImplementedError
80+
self.collective_rpc("profile", args=(is_start, ))
81+
82+
83+
class UniProcExecutor(UniProcExecutorV0, Executor):
84+
pass
85+
86+
87+
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
88+
pass
5689

57-
@abstractmethod
58-
def shutdown(self):
59-
pass
6090

61-
@abstractmethod
62-
def check_health(self) -> None:
63-
raise NotImplementedError
91+
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
92+
pass

vllm/v1/executor/multiproc_executor.py

Lines changed: 3 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
from vllm.utils import (get_distributed_init_method, get_mp_context,
2626
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
2727
from vllm.v1.executor.abstract import Executor
28-
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
29-
from vllm.v1.outputs import ModelRunnerOutput
3028
from vllm.worker.worker_base import WorkerWrapperBase
3129

3230
logger = init_logger(__name__)
@@ -37,7 +35,7 @@
3735

3836
class MultiprocExecutor(Executor):
3937

40-
def __init__(self, vllm_config: VllmConfig) -> None:
38+
def _init_executor(self) -> None:
4139
# Call self.shutdown at exit to clean up
4240
# and ensure workers will be terminated.
4341
self._finalizer = weakref.finalize(self, self.shutdown)
@@ -55,9 +53,6 @@ def sigusr1_handler(signum, frame):
5553

5654
signal.signal(signal.SIGUSR1, sigusr1_handler)
5755

58-
self.vllm_config = vllm_config
59-
self.parallel_config = vllm_config.parallel_config
60-
6156
self.world_size = self.parallel_config.world_size
6257
tensor_parallel_size = self.parallel_config.tensor_parallel_size
6358
assert self.world_size == tensor_parallel_size, (
@@ -82,7 +77,8 @@ def sigusr1_handler(signum, frame):
8277
# Create workers
8378
self.workers: List[WorkerProcHandle] = []
8479
for rank in range(self.world_size):
85-
worker = WorkerProc.make_worker_process(vllm_config, rank, rank,
80+
worker = WorkerProc.make_worker_process(self.vllm_config, rank,
81+
rank,
8682
distributed_init_method,
8783
scheduler_output_handle)
8884
self.workers.append(worker)
@@ -93,34 +89,6 @@ def sigusr1_handler(signum, frame):
9389
for w in self.workers:
9490
w.worker_response_mq.wait_until_ready()
9591

96-
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
97-
"""
98-
Initialize the KV caches and begin the model execution loop of the
99-
underlying workers.
100-
"""
101-
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
102-
self.collective_rpc("compile_or_warm_up_model")
103-
104-
def determine_available_memory(self) -> int:
105-
"""
106-
Determine the available memory (in bytes) for KV cache by invoking the
107-
underlying worker.
108-
"""
109-
memory_sizes = self.collective_rpc("determine_available_memory")
110-
111-
# Since we use a shared centralized controller, we take the minimum
112-
# memory size across all workers to make sure all the memory
113-
# operators can be applied to all workers.
114-
return min(memory_sizes)
115-
116-
def get_kv_cache_spec(self) -> KVCacheSpec:
117-
"""
118-
Get all kv cache needed by the model by invoking the underlying worker.
119-
"""
120-
kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
121-
assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
122-
return kv_cache_specs[0]
123-
12492
def collective_rpc(self,
12593
method: Union[str, Callable],
12694
timeout: Optional[float] = None,
@@ -172,18 +140,6 @@ def collective_rpc(self,
172140
# Re-raise any other exceptions
173141
raise e
174142

175-
def execute_model(
176-
self,
177-
scheduler_output,
178-
) -> ModelRunnerOutput:
179-
model_output = self.collective_rpc("execute_model",
180-
args=(scheduler_output, ))[0]
181-
return model_output
182-
183-
def profile(self, is_start: bool = True):
184-
self.collective_rpc("profile", args=(is_start, ))
185-
return
186-
187143
def _ensure_worker_termination(self):
188144
"""Ensure that all worker processes are terminated. Assumes workers have
189145
received termination requests. Waits for processing, then sends

0 commit comments

Comments
 (0)