Commit cfbca8a

[V1] TPU - Tensor parallel MP support (#15059)
1 parent 0fe5609 commit cfbca8a
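
For context, here is a minimal usage sketch of what this commit enables: with "tpu" no longer listed as a ray-only device, tensor-parallel inference on a single TPU host can go through the multiprocessing (MP) executor instead of requiring Ray. The model name and chip count below are placeholders, not part of the commit.

# Hedged sketch (not from the commit): run vLLM V1 tensor parallelism on one TPU host
# through the MP executor. Model name and tensor_parallel_size are illustrative only.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",    # placeholder model
    tensor_parallel_size=4,                 # one worker per local TPU chip
    distributed_executor_backend="mp",      # MP backend instead of Ray
)
print(llm.generate("Hello from TPU!"))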

2 files changed: 37 additions and 15 deletions

vllm/config.py

Lines changed: 1 addition & 1 deletion

@@ -1473,7 +1473,7 @@ def __post_init__(self) -> None:
             os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
             logger.info("Disabling V1 multiprocessing for external launcher.")
 
-        ray_only_devices = ["tpu"]
+        ray_only_devices: list[str] = []
         from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
                 and self.world_size > 1):
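
The effect of the one-line change above: the ray-only enforcement in ParallelConfig.__post_init__ can no longer trigger for TPU, so multi-chip TPU runs are free to use the MP executor. A tiny standalone sketch of the guard's behaviour (simplified, not the actual config code):

# Simplified, standalone illustration of the guard, assuming the check reduces to
# `device_type in ray_only_devices and world_size > 1`.
ray_only_devices: list[str] = []          # after this commit (was ["tpu"])

def requires_ray(device_type: str, world_size: int) -> bool:
    return device_type in ray_only_devices and world_size > 1

assert requires_ray("tpu", 8) is False    # previously True, which forced the Ray backend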

vllm/distributed/device_communicators/tpu_communicator.py

Lines changed: 36 additions & 14 deletions

@@ -6,16 +6,25 @@
 import torch
 from torch.distributed import ProcessGroup
 
+from vllm.config import get_current_vllm_config
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
 from .base_device_communicator import DeviceCommunicatorBase
 
+USE_RAY = parallel_config = get_current_vllm_config(
+).parallel_config.distributed_executor_backend == "ray"
+
+logger = init_logger(__name__)
+
 if current_platform.is_tpu():
+    import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
 
-    from vllm.executor import ray_utils
+    if USE_RAY:
+        from vllm.executor import ray_utils
 
 
 class TpuCommunicator(DeviceCommunicatorBase):
@@ -33,19 +42,32 @@ def __init__(self,
         global_rank = self.global_rank
         global_world_size = self.global_world_size
 
-        # Calculate how many TPU nodes are in the current deployment. This
-        # is the Ray placement group if it is deployed with Ray. Default
-        # to the number of TPU nodes in the Ray cluster. The number of TPU
-        # nodes is computed by the total number of TPUs divided by the
-        # number of TPU accelerators per node, to account for clusters
-        # with both CPUs and TPUs.
-        num_nodes = ray_utils.get_num_tpu_nodes()
-        num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
-        if num_nodes_in_pg > 0:
-            num_nodes = num_nodes_in_pg
-
-        local_world_size = global_world_size // num_nodes
-        local_rank = global_rank % local_world_size
+        if USE_RAY:
+            logger.info("TpuCommunicator initialized with RAY")
+            # Calculate how many TPU nodes are in the current deployment. This
+            # is the Ray placement group if it is deployed with Ray. Default
+            # to the number of TPU nodes in the Ray cluster. The number of TPU
+            # nodes is computed by the total number of TPUs divided by the
+            # number of TPU accelerators per node, to account for clusters
+            # with both CPUs and TPUs.
+            num_nodes = ray_utils.get_num_tpu_nodes()
+            num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
+            if num_nodes_in_pg > 0:
+                num_nodes = num_nodes_in_pg
+
+            local_world_size = global_world_size // num_nodes
+            local_rank = global_rank % local_world_size
+        else:
+            logger.info("TpuCommunicator initialized with MP")
+            # Sanity check: verify we are running on a single host
+            num_hosts = torch_xla.tpu.num_tpu_workers()
+            assert num_hosts == 1
+
+            # Get the number of TPU chips available locally
+            local_world_size = torch_xla.tpu.num_available_chips()
+
+            # Get the current local rank
+            local_rank = global_rank % local_world_size
 
         # Ensure environment variables are set for multihost deployments.
         # On GKE, this is needed for libtpu and TPU driver to know which TPU
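
To make the two code paths above easier to compare, here is a standalone sketch of the local-rank arithmetic, with the Ray and torch_xla helpers replaced by plain integers (the numbers are illustrative only):

# Illustrative sketch of the two local-topology derivations in TpuCommunicator.__init__.
# The real code obtains num_nodes from ray_utils and the local chip count from torch_xla.tpu.

def ray_local_topology(global_rank: int, global_world_size: int, num_nodes: int):
    # Ray path: split the global world evenly across TPU nodes.
    local_world_size = global_world_size // num_nodes
    local_rank = global_rank % local_world_size
    return local_world_size, local_rank

def mp_local_topology(global_rank: int, local_chip_count: int):
    # MP path: single host, so the local world is just the chips on this host.
    local_world_size = local_chip_count
    local_rank = global_rank % local_world_size
    return local_world_size, local_rank

# 8 chips spread over 2 Ray nodes: global rank 5 maps to local rank 1 on its node.
assert ray_local_topology(5, 8, 2) == (4, 1)
# 4 chips on a single host with MP: global rank equals local rank.
assert mp_local_topology(3, 4) == (4, 3)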
