[Hardware][TPU] Implement tensor parallelism with Ray #5871

Merged
merged 72 commits into main from tpu-n on Jul 27, 2024
Commits (72)
76fc072
Add & warnings
WoosukKwon Jun 24, 2024
27a5ad8
Add in dummy_run
WoosukKwon Jun 24, 2024
5ab6f65
Add is_driver_worker
WoosukKwon Jun 24, 2024
c4e79a0
Make TPUExecutor similar to GPUExecutor
WoosukKwon Jun 24, 2024
ff81993
Add multiprocessing-based TPU executor
WoosukKwon Jun 24, 2024
16e80b2
Use TPU to initialize Ray cluster
WoosukKwon Jun 24, 2024
05884ce
Add pjrt proc init
WoosukKwon Jun 24, 2024
20d23eb
Add Ray TPU executor
WoosukKwon Jun 24, 2024
5d4df21
Use Ray TPU executor for tp
WoosukKwon Jun 24, 2024
6b2c76c
Minor
WoosukKwon Jun 24, 2024
d91446b
Fix TPUWorker.execute_model
WoosukKwon Jun 24, 2024
ab1595d
Add is_driver_worker & input broadcast
WoosukKwon Jun 24, 2024
4b45393
Call xm._init_world_size_ordinal
WoosukKwon Jun 24, 2024
86451a2
Bug fix on vocab
WoosukKwon Jun 24, 2024
0539299
Use all gather for TPU
WoosukKwon Jun 24, 2024
b35917c
Support TPU in GroupCoordinator
WoosukKwon Jun 24, 2024
b9a84bc
Delete multiproc TPU executor
WoosukKwon Jun 25, 2024
c756b76
Minor
WoosukKwon Jun 25, 2024
16e9934
[Bugfix][TPU] Fix CPU cache allocation & swapping
WoosukKwon Jun 26, 2024
e25f470
Merge branch 'fix-tpu-swpa' into tpu-n
WoosukKwon Jun 26, 2024
ca6d1d6
yapf
WoosukKwon Jun 26, 2024
cd4f68d
Add Ray to TPU dependency
WoosukKwon Jun 26, 2024
5df4164
Merge branch 'main' into tpu-n
WoosukKwon Jun 26, 2024
546987a
Fix
WoosukKwon Jun 26, 2024
330be6e
Fix
WoosukKwon Jun 26, 2024
b45ed24
Merge branch 'main' into tpu-n
WoosukKwon Jun 29, 2024
8fab9fd
Add use_all_gather to LoRA
WoosukKwon Jun 29, 2024
c4cbe9f
Fix
WoosukKwon Jun 29, 2024
2871c7c
Merge branch 'main' into tpu-n
WoosukKwon Jun 30, 2024
db7adc7
Add an assert for dim == -1
WoosukKwon Jun 30, 2024
696790d
is_tpu -> use_xla
WoosukKwon Jun 30, 2024
8a08896
Merge branch 'main' into tpu-n
WoosukKwon Jun 30, 2024
36f9070
Merge branch 'main' into tpu-n
WoosukKwon Jul 1, 2024
28afe56
yapf
WoosukKwon Jul 2, 2024
60bf64d
Add hack in vocab
WoosukKwon Jul 2, 2024
0fbb050
Merge branch 'main' into tpu-n
WoosukKwon Jul 7, 2024
ddf4cbe
Merge branch 'main' into tpu-n
WoosukKwon Jul 7, 2024
cd4842d
Fix multi-modal support
WoosukKwon Jul 9, 2024
54e637b
Merge branch 'main' into tpu-n
WoosukKwon Jul 9, 2024
73ed611
Merge branch 'main' into tpu-n
WoosukKwon Jul 10, 2024
717b3fa
Merge branch 'main' into tpu-n
WoosukKwon Jul 15, 2024
6b0c35d
Merge branch 'main' into tpu-n
WoosukKwon Jul 17, 2024
7f583ba
Merge branch 'main' into tpu-n
WoosukKwon Jul 18, 2024
106864d
Remove unused
WoosukKwon Jul 18, 2024
223661f
Minor
WoosukKwon Jul 18, 2024
5bd67bc
Merge branch 'main' into tpu-n
WoosukKwon Jul 21, 2024
ab7cccf
Fix comm error
WoosukKwon Jul 21, 2024
4e0c90a
Use custom inference_mode
WoosukKwon Jul 21, 2024
a2358ed
Remove hack in vocab embedding
WoosukKwon Jul 21, 2024
ac21351
Use patch
WoosukKwon Jul 21, 2024
ba76d9e
Update inference_mode
WoosukKwon Jul 21, 2024
452c321
use_all_gather -> use_gather
WoosukKwon Jul 21, 2024
dcb63b7
Fix patch
WoosukKwon Jul 21, 2024
825cc44
Fix typo
WoosukKwon Jul 21, 2024
f27ef99
Merge branch 'main' into tpu-n
WoosukKwon Jul 22, 2024
9730288
Remove inference_mode
WoosukKwon Jul 22, 2024
631b08b
Add no_grad
WoosukKwon Jul 23, 2024
d65a7d0
Merge branch 'main' into tpu-n
WoosukKwon Jul 23, 2024
755fe0b
Merge branch 'main' into tpu-n
WoosukKwon Jul 24, 2024
d5fadfd
Merge branch 'main' into tpu-n
WoosukKwon Jul 26, 2024
af3a259
[TPU] Support collective communications in XLA devices
WoosukKwon Jul 26, 2024
0f2abea
Use current_platform
WoosukKwon Jul 26, 2024
8ebea7e
is_xla -> is_tpu
WoosukKwon Jul 26, 2024
782b182
Define TPU communicator
WoosukKwon Jul 26, 2024
76fd300
Merge branch 'main' into tpu-n
WoosukKwon Jul 26, 2024
75f842b
Merge branch 'add-xla-comm' into tpu-n
WoosukKwon Jul 26, 2024
8087227
Fix
WoosukKwon Jul 26, 2024
f04e179
Address comments
WoosukKwon Jul 26, 2024
f493c89
Device init
WoosukKwon Jul 26, 2024
f14b085
Fix patch
WoosukKwon Jul 26, 2024
1668582
Merge branch 'add-xla-comm' into tpu-n
WoosukKwon Jul 26, 2024
a05cf0f
Merge branch 'main' into tpu-n
WoosukKwon Jul 27, 2024
1 change: 1 addition & 0 deletions requirements-tpu.txt
@@ -4,4 +4,5 @@
 # Dependencies for TPU
 # Currently, the TPU backend uses a nightly version of PyTorch XLA.
 # You can install the dependencies in Dockerfile.tpu.
+ray
 triton # To avoid import errors
4 changes: 2 additions & 2 deletions vllm/attention/backends/pallas.py
@@ -55,8 +55,8 @@ class PallasMetadata(AttentionMetadata):
 
     # Currently, input sequences can only contain all prefills
     # or all decoding.
-    block_tables: Optional[torch.Tensor]
-    context_lens: Optional[torch.Tensor]
+    block_tables: Optional[torch.Tensor] = None
+    context_lens: Optional[torch.Tensor] = None
 
     @property
     def prefill_metadata(self) -> Optional["PallasMetadata"]:
10 changes: 8 additions & 2 deletions vllm/engine/llm_engine.py
@@ -394,8 +394,14 @@ def _get_executor_cls(cls,
             from vllm.executor.neuron_executor import NeuronExecutor
             executor_class = NeuronExecutor
         elif engine_config.device_config.device_type == "tpu":
-            from vllm.executor.tpu_executor import TPUExecutor
-            executor_class = TPUExecutor
+            if distributed_executor_backend == "ray":
+                initialize_ray_cluster(engine_config.parallel_config)
+                from vllm.executor.ray_tpu_executor import RayTPUExecutor
+                executor_class = RayTPUExecutor
+            else:
+                assert distributed_executor_backend is None
+                from vllm.executor.tpu_executor import TPUExecutor
+                executor_class = TPUExecutor
         elif engine_config.device_config.device_type == "cpu":
             from vllm.executor.cpu_executor import CPUExecutor
             executor_class = CPUExecutor
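
For readers skimming the diff, a minimal usage sketch of how this branch is reached (not part of the PR; the model name, tensor-parallel size, and the exact LLM keyword arguments are illustrative assumptions for a mid-2024 vLLM build): requesting the Ray distributed backend on a TPU device routes _get_executor_cls to RayTPUExecutor, while a single-device TPU run keeps the plain TPUExecutor.

# Hypothetical usage sketch, not code from this PR. On a TPU host, selecting the
# Ray backend makes _get_executor_cls return RayTPUExecutor for tensor parallelism.
from vllm import LLM

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B",   # placeholder model name
    tensor_parallel_size=4,               # placeholder; more than one chip needs the Ray backend
    distributed_executor_backend="ray",   # routes to RayTPUExecutor in the code above
)
outputs = llm.generate(["Tensor parallelism on TPU via Ray"])

Depending on the engine defaults of the installed version, the backend may also be selected automatically when the world size is greater than one; passing it explicitly just makes the executor choice visible.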
313 changes: 313 additions & 0 deletions vllm/executor/ray_tpu_executor.py
@@ -0,0 +1,313 @@
import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
                    Union)

import vllm.envs as envs
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.executor.tpu_executor import TPUExecutor
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        get_vllm_instance_id, make_async)

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


class RayTPUExecutor(TPUExecutor):

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def _init_executor(self) -> None:
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel TPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("TPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            assert self.speculative_config is None
            worker_module_name = "vllm.worker.tpu_worker"
            worker_class_name = "TPUWorker"

            worker = ray.remote(
                num_cpus=0,
                resources={"TPU": 1},
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Else, added to the list of workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any TPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "TPU node.")

        # Get the set of TPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)

        VLLM_INSTANCE_ID = get_vllm_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "VLLM_INSTANCE_ID":
            VLLM_INSTANCE_ID,
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
        }, ) for _ in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_workers) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - async_run_remote_workers_only: If True the method will be run only
          in the remote workers, not the driver worker. It will also be
          run asynchronously and return a list of futures rather than blocking
          on the results.
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))
        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_blocks = self._run_workers("determine_num_available_blocks", )
        num_tpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)


class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)
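
As a closing note on the execution flow above: _run_workers dispatches the method to the Ray actors first, runs the driver-side call locally, and only then gathers the remote results. The following standalone sketch (plain Ray, with a hypothetical EchoWorker actor; none of these names are vLLM APIs) shows that same dispatch-then-gather pattern.

# Standalone illustration of the _run_workers() pattern, assuming a local Ray
# runtime. All names here are hypothetical, not vLLM code.
import ray

ray.init()

@ray.remote
class EchoWorker:
    def execute_method(self, method: str, *args):
        # Stand-in for RayWorkerWrapper.execute_method on a remote worker.
        return f"{method}{args} ran on a remote worker"

workers = [EchoWorker.remote() for _ in range(3)]

# 1) Kick off the remote workers first (non-blocking futures).
futures = [w.execute_method.remote("execute_model", 42) for w in workers]
# 2) Run the driver-side call locally.
driver_output = "execute_model ran on the driver"
# 3) Block on the remote results, returning the driver output first,
#    mirroring `[driver_worker_output] + ray_worker_outputs`.
print([driver_output] + ray.get(futures))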