Skip to content

Commit eac7762

Browse files
WoosukKwonLeiWang1999
authored andcommitted
[Hardware][TPU] Implement tensor parallelism with Ray (vllm-project#5871)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent 24a3cbb commit eac7762

File tree

6 files changed

+365
-21
lines changed

6 files changed

+365
-21
lines changed

requirements-tpu.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
# Dependencies for TPU
55
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
66
# You can install the dependencies in Dockerfile.tpu.
7+
ray
78
triton # To avoid import errors

vllm/attention/backends/pallas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ class PallasMetadata(AttentionMetadata):
5555

5656
# Currently, input sequences can only contain all prefills
5757
# or all decoding.
58-
block_tables: Optional[torch.Tensor]
59-
context_lens: Optional[torch.Tensor]
58+
block_tables: Optional[torch.Tensor] = None
59+
context_lens: Optional[torch.Tensor] = None
6060

6161
@property
6262
def prefill_metadata(self) -> Optional["PallasMetadata"]:

vllm/engine/llm_engine.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,14 @@ def _get_executor_cls(cls,
394394
from vllm.executor.neuron_executor import NeuronExecutor
395395
executor_class = NeuronExecutor
396396
elif engine_config.device_config.device_type == "tpu":
397-
from vllm.executor.tpu_executor import TPUExecutor
398-
executor_class = TPUExecutor
397+
if distributed_executor_backend == "ray":
398+
initialize_ray_cluster(engine_config.parallel_config)
399+
from vllm.executor.ray_tpu_executor import RayTPUExecutor
400+
executor_class = RayTPUExecutor
401+
else:
402+
assert distributed_executor_backend is None
403+
from vllm.executor.tpu_executor import TPUExecutor
404+
executor_class = TPUExecutor
399405
elif engine_config.device_config.device_type == "cpu":
400406
from vllm.executor.cpu_executor import CPUExecutor
401407
executor_class = CPUExecutor

vllm/executor/ray_tpu_executor.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
import asyncio
2+
import os
3+
from collections import defaultdict
4+
from itertools import islice, repeat
5+
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
6+
Union)
7+
8+
import vllm.envs as envs
9+
from vllm.executor.executor_base import ExecutorAsyncBase
10+
from vllm.executor.ray_utils import RayWorkerWrapper, ray
11+
from vllm.executor.tpu_executor import TPUExecutor
12+
from vllm.logger import init_logger
13+
from vllm.sequence import ExecuteModelRequest, SamplerOutput
14+
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
15+
get_vllm_instance_id, make_async)
16+
17+
if ray is not None:
18+
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
19+
20+
if TYPE_CHECKING:
21+
from ray.util.placement_group import PlacementGroup
22+
23+
logger = init_logger(__name__)
24+
25+
26+
class RayTPUExecutor(TPUExecutor):
27+
28+
def __init__(self, *args, **kwargs):
29+
# This is non-None when the execute model loop is running
30+
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
31+
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
32+
# Updated by implementations that require additional args to be passed
33+
# to the _run_workers execute_model call
34+
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
35+
36+
super().__init__(*args, **kwargs)
37+
38+
def _init_executor(self) -> None:
39+
assert self.parallel_config.distributed_executor_backend == "ray"
40+
placement_group = self.parallel_config.placement_group
41+
42+
# Disable Ray usage stats collection.
43+
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
44+
if ray_usage != "1":
45+
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
46+
47+
# Create the parallel TPU workers.
48+
self._init_workers_ray(placement_group)
49+
50+
def _init_workers_ray(self, placement_group: "PlacementGroup",
51+
**ray_remote_kwargs):
52+
# The driver dummy worker does not actually use any resources.
53+
# It holds the resource for the driver worker.
54+
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
55+
# The remaining workers are the actual ray actors.
56+
self.workers: List[RayWorkerWrapper] = []
57+
58+
# Create the workers.
59+
driver_ip = get_ip()
60+
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
61+
if not bundle.get("TPU", 0):
62+
continue
63+
scheduling_strategy = PlacementGroupSchedulingStrategy(
64+
placement_group=placement_group,
65+
placement_group_capture_child_tasks=True,
66+
placement_group_bundle_index=bundle_id,
67+
)
68+
69+
assert self.speculative_config is None
70+
worker_module_name = "vllm.worker.tpu_worker"
71+
worker_class_name = "TPUWorker"
72+
73+
worker = ray.remote(
74+
num_cpus=0,
75+
resources={"TPU": 1},
76+
scheduling_strategy=scheduling_strategy,
77+
**ray_remote_kwargs,
78+
)(RayWorkerWrapper).remote(
79+
worker_module_name=worker_module_name,
80+
worker_class_name=worker_class_name,
81+
trust_remote_code=self.model_config.trust_remote_code,
82+
)
83+
84+
worker_ip = ray.get(worker.get_node_ip.remote())
85+
if worker_ip == driver_ip and self.driver_dummy_worker is None:
86+
# If the worker is on the same node as the driver, we use it
87+
# as the resource holder for the driver process.
88+
self.driver_dummy_worker = worker
89+
self.driver_worker = RayWorkerWrapper(
90+
worker_module_name=worker_module_name,
91+
worker_class_name=worker_class_name,
92+
trust_remote_code=self.model_config.trust_remote_code,
93+
)
94+
else:
95+
# Else, added to the list of workers.
96+
self.workers.append(worker)
97+
98+
if self.driver_dummy_worker is None:
99+
raise ValueError(
100+
"Ray does not allocate any TPUs on the driver node. Consider "
101+
"adjusting the Ray placement group or running the driver on a "
102+
"TPU node.")
103+
104+
# Get the set of TPU IDs used on each node.
105+
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
106+
use_dummy_driver=True)
107+
108+
node_workers = defaultdict(list)
109+
for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
110+
node_workers[node_id].append(i)
111+
112+
VLLM_INSTANCE_ID = get_vllm_instance_id()
113+
114+
# Set environment variables for the driver and workers.
115+
all_args_to_update_environment_variables = [({
116+
"VLLM_INSTANCE_ID":
117+
VLLM_INSTANCE_ID,
118+
"VLLM_TRACE_FUNCTION":
119+
str(envs.VLLM_TRACE_FUNCTION),
120+
}, ) for _ in worker_node_and_gpu_ids]
121+
self._run_workers("update_environment_variables",
122+
all_args=all_args_to_update_environment_variables)
123+
124+
if len(node_workers) == 1:
125+
# in single node case, we don't need to get the IP address.
126+
# the loopback address is sufficient
127+
# NOTE: a node may have several IP addresses, one for each
128+
# network interface. `get_ip()` might return any of them,
129+
# while they might not work for communication inside the node
130+
# if the network setup is complicated. Using the loopback address
131+
# solves this issue, as it always works for communication inside
132+
# the node.
133+
driver_ip = "127.0.0.1"
134+
distributed_init_method = get_distributed_init_method(
135+
driver_ip, get_open_port())
136+
137+
# Initialize the actual workers inside worker wrapper.
138+
init_worker_all_kwargs = [
139+
self._get_worker_kwargs(
140+
local_rank=node_workers[node_id].index(rank),
141+
rank=rank,
142+
distributed_init_method=distributed_init_method,
143+
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
144+
]
145+
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
146+
147+
self._run_workers("init_device")
148+
self._run_workers("load_model",
149+
max_concurrent_workers=self.parallel_config.
150+
max_parallel_loading_workers)
151+
152+
def _driver_execute_model(
153+
self,
154+
execute_model_req: Optional[ExecuteModelRequest] = None
155+
) -> List[SamplerOutput]:
156+
"""Run execute_model in the driver worker.
157+
158+
Passing None will cause the driver to stop the model execution
159+
loop running in each of the remote workers.
160+
"""
161+
return self.driver_worker.execute_method("execute_model",
162+
execute_model_req)
163+
164+
def _run_workers(
165+
self,
166+
method: str,
167+
*args,
168+
async_run_remote_workers_only: bool = False,
169+
all_args: Optional[List[Tuple[Any, ...]]] = None,
170+
all_kwargs: Optional[List[Dict[str, Any]]] = None,
171+
use_dummy_driver: bool = False,
172+
max_concurrent_workers: Optional[int] = None,
173+
use_ray_compiled_dag: bool = False,
174+
**kwargs,
175+
) -> Any:
176+
"""Runs the given method on all workers. Can be used in the following
177+
ways:
178+
179+
- async_run_remote_workers_only: If True the method will be run only
180+
in the remote workers, not the driver worker. It will also be
181+
run asynchronously and return a list of futures rather than blocking
182+
on the results.
183+
- args/kwargs: All workers share the same args/kwargs
184+
- all_args/all_kwargs: args/kwargs for each worker are specified
185+
individually
186+
"""
187+
188+
if max_concurrent_workers:
189+
raise NotImplementedError(
190+
"max_concurrent_workers is not supported yet.")
191+
192+
count = len(self.workers)
193+
all_worker_args = repeat(args, count) if all_args is None \
194+
else islice(all_args, 1, None)
195+
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
196+
else islice(all_kwargs, 1, None)
197+
198+
# Start the ray workers first.
199+
ray_worker_outputs = [
200+
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
201+
for (worker, worker_args, worker_kwargs
202+
) in zip(self.workers, all_worker_args, all_worker_kwargs)
203+
]
204+
205+
if async_run_remote_workers_only:
206+
# Just return futures
207+
return ray_worker_outputs
208+
209+
driver_args = args if all_args is None else all_args[0]
210+
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
211+
212+
# Start the driver worker after all the ray workers.
213+
if not use_dummy_driver:
214+
driver_worker_output = self.driver_worker.execute_method(
215+
method, *driver_args, **driver_kwargs)
216+
else:
217+
assert self.driver_dummy_worker is not None
218+
driver_worker_output = ray.get(
219+
self.driver_dummy_worker.execute_method.remote(
220+
method, *driver_args, **driver_kwargs))
221+
# Get the results of the ray workers.
222+
if self.workers:
223+
ray_worker_outputs = ray.get(ray_worker_outputs)
224+
225+
return [driver_worker_output] + ray_worker_outputs
226+
227+
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
228+
"""Wait for futures returned from _run_workers() with
229+
async_run_remote_workers_only to complete."""
230+
ray.get(parallel_worker_tasks)
231+
232+
def determine_num_available_blocks(self) -> Tuple[int, int]:
233+
num_blocks = self._run_workers("determine_num_available_blocks", )
234+
num_tpu_blocks = min(b[0] for b in num_blocks)
235+
num_cpu_blocks = min(b[1] for b in num_blocks)
236+
return num_tpu_blocks, num_cpu_blocks
237+
238+
def initialize_cache(self, num_gpu_blocks: int,
239+
num_cpu_blocks: int) -> None:
240+
logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
241+
num_cpu_blocks)
242+
self.cache_config.num_gpu_blocks = num_gpu_blocks
243+
self.cache_config.num_cpu_blocks = num_cpu_blocks
244+
self._run_workers("initialize_cache",
245+
num_gpu_blocks=num_gpu_blocks,
246+
num_cpu_blocks=num_cpu_blocks)
247+
248+
def execute_model(
249+
self,
250+
execute_model_req: ExecuteModelRequest,
251+
) -> List[SamplerOutput]:
252+
if self.parallel_worker_tasks is None:
253+
self.parallel_worker_tasks = self._run_workers(
254+
"start_worker_execution_loop",
255+
async_run_remote_workers_only=True,
256+
**self.extra_execute_model_run_workers_kwargs)
257+
258+
# Only the driver worker returns the sampling results.
259+
return self._driver_execute_model(execute_model_req)
260+
261+
def stop_remote_worker_execution_loop(self) -> None:
262+
if self.parallel_worker_tasks is None:
263+
return
264+
265+
self._driver_execute_model()
266+
parallel_worker_tasks = self.parallel_worker_tasks
267+
self.parallel_worker_tasks = None
268+
# Ensure that workers exit model loop cleanly
269+
# (this will raise otherwise)
270+
self._wait_for_tasks_completion(parallel_worker_tasks)
271+
272+
273+
class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):
274+
275+
def __init__(self, *args, **kwargs):
276+
super().__init__(*args, **kwargs)
277+
self.driver_exec_method = make_async(self.driver_worker.execute_method)
278+
279+
async def execute_model_async(
280+
self,
281+
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
282+
if self.parallel_worker_tasks is None:
283+
# Start model execution loop running in the parallel workers
284+
self.parallel_worker_tasks = asyncio.create_task(
285+
self._start_worker_execution_loop())
286+
287+
# Only the driver worker returns the sampling results.
288+
return await self._driver_execute_model_async(execute_model_req)
289+
290+
async def stop_remote_worker_execution_loop_async(self) -> None:
291+
if self.parallel_worker_tasks is None:
292+
return
293+
294+
await self._driver_execute_model_async()
295+
parallel_worker_tasks = self.parallel_worker_tasks
296+
self.parallel_worker_tasks = None
297+
# Ensure that workers exit model loop cleanly
298+
# (this will raise otherwise)
299+
await parallel_worker_tasks
300+
301+
async def _driver_execute_model_async(
302+
self,
303+
execute_model_req: Optional[ExecuteModelRequest] = None
304+
) -> List[SamplerOutput]:
305+
return await self.driver_exec_method("execute_model",
306+
execute_model_req)
307+
308+
async def _start_worker_execution_loop(self):
309+
coros = [
310+
worker.execute_method.remote("start_worker_execution_loop")
311+
for worker in self.workers
312+
]
313+
return await asyncio.gather(*coros)

0 commit comments

Comments
 (0)