Commit 87a0c07

Authored by youkaichao
[core] allow callable in collective_rpc (#12151)
Signed-off-by: youkaichao <[email protected]>
1 parent d4e6194 commit 87a0c07

13 files changed: +147 −50 lines
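In short: `collective_rpc` previously accepted only the name of a worker method as a string; with this change it also accepts a callable, which is cloudpickled, sent to every worker, and invoked with the worker object as its first (`self`) argument. A minimal usage sketch, adapted from the test added in this commit (model name, `load_format="dummy"`, and the worker's `rank` attribute are taken from that test; the custom worker class used there is omitted, since a plain callable does not need one):

from vllm import LLM

def echo_rank(self):
    # `self` is the worker object on each rank
    return self.rank

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          enforce_eager=True,
          load_format="dummy",
          tensor_parallel_size=2)

# one result per worker, e.g. [0, 1] for tensor_parallel_size=2
print(llm.collective_rpc(echo_rank))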

.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 1 deletion
@@ -107,7 +107,7 @@ steps:
   source_file_dependencies:
   - vllm/
   commands:
-  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
+  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
   - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
@@ -466,7 +466,9 @@ steps:
   - vllm/worker/worker_base.py
   - vllm/worker/worker.py
   - vllm/worker/model_runner.py
+  - entrypoints/llm/test_collective_rpc.py
   commands:
+  - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py

tests/engine/test_custom_executor.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 import asyncio
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import pytest

@@ -18,7 +18,7 @@ class Mock:
 class CustomUniExecutor(UniProcExecutor):

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:

tests/entrypoints/llm/test_collective_rpc.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import pytest
+
+from vllm import LLM
+
+from ...utils import fork_new_process_for_each_test
+
+
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("backend", ["mp", "ray"])
+@fork_new_process_for_each_test
+def test_collective_rpc(tp_size, backend):
+    if tp_size == 1 and backend == "ray":
+        pytest.skip("Skip duplicate test case")
+    if tp_size == 1:
+        backend = None
+
+    # intentionally define the method and class in the test function,
+    # to test if they can be serialized and sent to the workers
+    def echo_rank(self):
+        return self.rank
+
+    from vllm.worker.worker import Worker
+
+    class MyWorker(Worker):
+
+        def echo_rank(self):
+            return self.rank
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              load_format="dummy",
+              tensor_parallel_size=tp_size,
+              distributed_executor_backend=backend,
+              worker_cls=MyWorker)
+    for method in ["echo_rank", echo_rank]:
+        assert llm.collective_rpc(method) == list(range(tp_size))

vllm/engine/llm_engine.py

Lines changed: 14 additions & 3 deletions
@@ -5,10 +5,10 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
-                    List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
+                    Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Type, Union, cast, overload
+from typing import Set, Tuple, Type, Union, cast, overload

 import torch
 from typing_extensions import TypeVar, deprecated
@@ -1816,6 +1816,17 @@ def start_profile(self) -> None:
     def stop_profile(self) -> None:
         self.model_executor.stop_profile()

+    def collective_rpc(self,
+                       method: Union[str, Callable],
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> List[Any]:
+        """
+        See LLM.collective_rpc for more details.
+        """
+        return self.model_executor.collective_rpc(method, timeout, args,
+                                                  kwargs)
+
     def check_health(self) -> None:
         if self.tokenizer:
             self.tokenizer.check_health()

vllm/entrypoints/llm.py

Lines changed: 9 additions & 5 deletions
@@ -1,8 +1,8 @@
 import itertools
 import warnings
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
-                    Union, cast, overload)
+from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
+                    Tuple, Type, Union, cast, overload)

 import cloudpickle
 from tqdm import tqdm
@@ -464,7 +464,7 @@ def generate(
         return self.engine_class.validate_outputs(outputs, RequestOutput)

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -476,9 +476,13 @@ def collective_rpc(self,
         Then, users can call the new methods through this API.
         It is recommended to use this API to only pass control messages,
         and set up data-plane communication to pass data.
+        The method can also be a callable, which will be serialized
+        and sent to all workers to execute.
+        If the method is a callable, it should accept an additional
+        `self` argument, in addition to the arguments passed in `args`
+        and `kwargs`. The `self` argument will be the worker object.
         """
-        return self.llm_engine.model_executor.collective_rpc(
-            method, timeout, args, kwargs)
+        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

     def beam_search(
         self,
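To make the new part of the docstring concrete: when `method` is a callable, the worker object is prepended as `self`, and `args`/`kwargs` are forwarded after it. A hedged illustration, continuing from an `LLM` instance named `llm` (the `apply_scale` function and its parameters are invented for this example, not part of the commit):

def apply_scale(self, scale, verbose=False):
    # `self` is the worker; `scale` comes from `args`, `verbose` from `kwargs`
    if verbose:
        print(f"rank {self.rank}: scaling by {scale}")
    return self.rank * scale

# roughly equivalent to calling apply_scale(worker, 2.0, verbose=True) on every worker
results = llm.collective_rpc(apply_scale, args=(2.0,), kwargs={"verbose": True})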

vllm/executor/executor_base.py

Lines changed: 5 additions & 4 deletions
@@ -1,6 +1,7 @@
 import asyncio
 from abc import ABC, abstractmethod
-from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
+from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
+                    Union)

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -47,7 +48,7 @@ def _init_executor(self) -> None:

     @abstractmethod
     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -260,7 +261,7 @@ def _driver_execute_model(
         raise NotImplementedError

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -269,7 +270,7 @@ def collective_rpc(self,
     @abstractmethod
     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,

vllm/executor/mp_distributed_executor.py

Lines changed: 14 additions & 7 deletions
@@ -1,5 +1,7 @@
 import asyncio
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional, Union
+
+import cloudpickle

 from vllm.executor.executor_base import DistributedExecutorBase
 from vllm.executor.multiproc_worker_utils import (
@@ -9,7 +11,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, make_async)
+                        get_ip, get_open_port, make_async, run_method)
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -107,7 +109,7 @@ def _driver_execute_model(

     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -121,6 +123,11 @@ def _run_workers(
          It will also be run asynchronously and return a list of futures
          rather than blocking on the results.
         """
+        if isinstance(method, str):
+            sent_method = method
+        else:
+            sent_method = cloudpickle.dumps(method)
+        del method

         if max_concurrent_workers:
             raise NotImplementedError(
@@ -129,18 +136,18 @@ def _run_workers(
         if async_run_tensor_parallel_workers_only:
             # Run only non-driver workers and just return futures.
             return [
-                worker.execute_method(method, *args, **kwargs)
+                worker.execute_method(sent_method, *args, **kwargs)
                 for worker in self.non_driver_workers
             ]

         # Start all remote workers first.
         worker_outputs = [
-            worker.execute_method(method, *args, **kwargs)
+            worker.execute_method(sent_method, *args, **kwargs)
             for worker in self.workers
         ]

-        driver_worker_method = getattr(self.driver_worker, method)
-        driver_worker_output = driver_worker_method(*args, **kwargs)
+        driver_worker_output = run_method(self.driver_worker, sent_method,
+                                          args, kwargs)

         # Get the results of the workers.
         return [driver_worker_output
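`run_method`, newly imported from `vllm.utils`, is not shown in this diff (vllm/utils.py is among the changed files not reproduced here). Judging from the call sites above, a plausible sketch of its dispatch logic is the following; treat it as an illustration under those assumptions rather than the exact implementation:

from functools import partial
from typing import Any, Callable, Dict, Tuple, Union

import cloudpickle

def run_method(obj: Any, method: Union[str, bytes, Callable],
               args: Tuple, kwargs: Dict[str, Any]) -> Any:
    if isinstance(method, bytes):
        # a cloudpickled callable received from another process
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        # the original path: look the method up on the worker by name
        func = getattr(obj, method)
    else:
        # an in-process callable: bind the worker as its `self` argument
        func = partial(method, obj)
    return func(*args, **kwargs)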

vllm/executor/multiproc_worker_utils.py

Lines changed: 6 additions & 6 deletions
@@ -15,7 +15,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.triton_utils.importing import HAS_TRITON
-from vllm.utils import _check_multiproc_method, get_mp_context
+from vllm.utils import _check_multiproc_method, get_mp_context, run_method

 if HAS_TRITON:
     from vllm.triton_utils import maybe_set_triton_cache_manager
@@ -169,7 +169,7 @@ def __init__(self, result_handler: ResultHandler,
         self.process.start()

     def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
-                      method: str, args, kwargs):
+                      method: Union[str, bytes], args, kwargs):
         task_id = uuid.uuid4()
         self.tasks[task_id] = future
         try:
@@ -180,12 +180,13 @@ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
             del self.tasks[task_id]
             raise ChildProcessError("worker died") from e

-    def execute_method(self, method: str, *args, **kwargs):
+    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         future: ResultFuture = ResultFuture()
         self._enqueue_task(future, method, args, kwargs)
         return future

-    async def execute_method_async(self, method: str, *args, **kwargs):
+    async def execute_method_async(self, method: Union[str, bytes], *args,
+                                   **kwargs):
         future = asyncio.get_running_loop().create_future()
         self._enqueue_task(future, method, args, kwargs)
         return await future
@@ -230,8 +231,7 @@ def _run_worker_process(
        exception = None
        task_id, method, args, kwargs = items
        try:
-            executor = getattr(worker, method)
-            output = executor(*args, **kwargs)
+            output = run_method(worker, method, args, kwargs)
        except SystemExit:
            raise
        except KeyboardInterrupt:
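`method` widens to `Union[str, bytes]` in this file because the executor cloudpickles a callable before it crosses the process boundary: only the raw bytes travel through the task queue, and `run_method` rebuilds the function inside the worker process. A small standalone illustration of that round trip (the function name is illustrative):

import cloudpickle

def echo_rank(self):
    # defined at the call site, e.g. inside a test function
    return self.rank

payload = cloudpickle.dumps(echo_rank)   # bytes, safe to send through the worker pipe
restored = cloudpickle.loads(payload)    # reconstructed in the worker process
# restored(worker) now behaves like echo_rank(worker)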

vllm/executor/ray_distributed_executor.py

Lines changed: 10 additions & 4 deletions
@@ -2,8 +2,9 @@
 import os
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

+import cloudpickle
 import msgspec

 import vllm.envs as envs
@@ -410,7 +411,7 @@ def execute_model(

     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -426,6 +427,11 @@ def _run_workers(
           rather than blocking on the results.
         - args/kwargs: All workers share the same args/kwargs
         """
+        if isinstance(method, str):
+            sent_method = method
+        else:
+            sent_method = cloudpickle.dumps(method)
+        del method
         if self.use_ray_spmd_worker:
             assert not async_run_tensor_parallel_workers_only, (
                 "async_run_tensor_parallel_workers_only is not supported for "
@@ -440,7 +446,7 @@ def _run_workers(
         if async_run_tensor_parallel_workers_only:
             ray_workers = self.non_driver_workers
             ray_worker_outputs = [
-                worker.execute_method.remote(method, *args, **kwargs)
+                worker.execute_method.remote(sent_method, *args, **kwargs)
                 for worker in ray_workers
             ]

@@ -455,7 +461,7 @@ def _run_workers(
         if not self.use_ray_spmd_worker:
             # Start the driver worker after all the ray workers.
             driver_worker_output = [
-                self.driver_worker.execute_method(method, *args, **kwargs)
+                self.driver_worker.execute_method(sent_method, *args, **kwargs)
             ]

         # Get the results of the ray workers.

vllm/executor/uniproc_executor.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import os
2-
from typing import Any, Dict, List, Optional, Tuple
2+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
33

44
import torch
55
import torch.distributed as dist
66

77
import vllm.envs as envs
88
from vllm.executor.executor_base import ExecutorBase
99
from vllm.logger import init_logger
10-
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
10+
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
11+
run_method)
1112
from vllm.worker.worker_base import WorkerWrapperBase
1213

1314
logger = init_logger(__name__)
@@ -39,18 +40,13 @@ def _init_executor(self) -> None:
3940
self.collective_rpc("load_model")
4041

4142
def collective_rpc(self,
42-
method: str,
43+
method: Union[str, Callable],
4344
timeout: Optional[float] = None,
4445
args: Tuple = (),
4546
kwargs: Optional[Dict] = None) -> List[Any]:
4647
if kwargs is None:
4748
kwargs = {}
48-
try:
49-
func = getattr(self.driver_worker, method)
50-
except AttributeError:
51-
raise NotImplementedError(f"Method {method} is not implemented.") \
52-
from None
53-
answer = func(*args, **kwargs)
49+
answer = run_method(self.driver_worker, method, args, kwargs)
5450
return [answer]
5551

5652
def check_health(self) -> None:
