|
1 |
| -from abc import ABC, abstractmethod |
2 | 1 | from typing import Type
|
3 | 2 |
|
4 | 3 | from vllm.config import VllmConfig
|
| 4 | +from vllm.executor.executor_base import ExecutorBase |
| 5 | +from vllm.executor.ray_distributed_executor import ( # noqa |
| 6 | + RayDistributedExecutor as RayDistributedExecutorV0) |
| 7 | +from vllm.executor.uniproc_executor import ( # noqa |
| 8 | + ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0) |
| 9 | +from vllm.executor.uniproc_executor import ( # noqa |
| 10 | + UniProcExecutor as UniProcExecutorV0) |
5 | 11 | from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
6 | 12 | from vllm.v1.outputs import ModelRunnerOutput
|
7 | 13 |
|
8 | 14 |
|
class Executor(ExecutorBase):
    """
    Abstract class for v1 executors, mainly define some methods for v1.
    For methods shared by v0 and v1, define them in ExecutorBase"""

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
        """Select the concrete v1 executor class for the given config.

        When the user does not specify a distributed executor backend,
        falls back to "mp" for multi-worker setups and "uni" otherwise.

        Raises:
            ValueError: if the configured backend name is not recognized.
        """
        executor_class: Type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = (
            parallel_config.distributed_executor_backend)
        if distributed_executor_backend is None:
            # If the user does not specify the distributed executor backend,
            # we will choose the backend based on the world size.
            if parallel_config.world_size > 1:
                distributed_executor_backend = "mp"
            else:
                distributed_executor_backend = "uni"

        if distributed_executor_backend == "ray":
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            # Imported lazily so the multiprocessing machinery is only
            # pulled in when this backend is actually selected.
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        else:
            raise ValueError("Unknown distributed executor backend: "
                             f"{distributed_executor_backend}")
        return executor_class

    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
        self.collective_rpc("compile_or_warm_up_model")

    def determine_available_memory(self) -> int:  # in bytes
        """Return the available memory (in bytes) usable across all workers."""
        output = self.collective_rpc("determine_available_memory")
        # Since we use a shared centralized controller, we take the minimum
        # memory size across all workers to make sure all the memory
        # operators can be applied to all workers.
        return min(output)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Return the KV cache spec, which must be identical on all workers."""
        output = self.collective_rpc("get_kv_cache_spec")
        # All workers must agree on the spec. A bare `assert` would be
        # stripped under `python -O`, so validate explicitly and raise.
        for spec in output[1:]:
            if spec != output[0]:
                raise RuntimeError(
                    "All workers must have the same KV cache spec, but got "
                    f"{spec} != {output[0]}")
        return output[0]

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        """Run one model step on all workers; return rank 0's output."""
        output = self.collective_rpc("execute_model",
                                     args=(scheduler_output, ))
        return output[0]

    def profile(self, is_start: bool = True):
        """Start (or stop, when is_start is False) profiling on all workers."""
        self.collective_rpc("profile", args=(is_start, ))
| 82 | + |
class UniProcExecutor(UniProcExecutorV0, Executor):
    """Single-process executor: v0 uniproc implementation with the v1
    `Executor` interface mixed in."""
| 85 | + |
| 86 | + |
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    """Externally-launched executor: v0 external-launcher implementation
    with the v1 `Executor` interface mixed in."""
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    """Ray-based distributed executor: v0 Ray implementation with the v1
    `Executor` interface mixed in."""
0 commit comments