Skip to content

Commit 01649f9

Browse files
committed
[V1] TPU - Add tensor parallel support via Ray
Signed-off-by: Alexander Matveev <[email protected]>
1 parent a02c86b commit 01649f9

File tree

5 files changed

+28
-7
lines changed

5 files changed

+28
-7
lines changed

requirements-tpu.txt

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ ray[default]
17 17
--find-links https://storage.googleapis.com/libtpu-releases/index.html
18 18
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
19 19
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
20-
torch==2.6.0.dev20241216+cpu
21-
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
22-
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
23-
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
20+
torch==2.7.0.dev20250212+cpu
21+
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250212+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
22+
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250212+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
23+
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250212+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

tests/entrypoints/llm/test_accuracy.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,10 @@ def run_test(more_args=None):
42 42
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
43 43

44 44

45+
# TODO: [AlexM] Fix it with new CI/CD tests
46+
TPU_TP_TEST_STR = "" #"tensor_parallel_size=4"
47+
48+
45 49
@pytest.mark.skipif(not current_platform.is_cuda()
46 50
and not current_platform.is_tpu(),
47 51
reason="V1 is currently only supported on CUDA and TPU")
@@ -56,6 +60,10 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch):
56 60
# Limit compilation time for TPU V1
57 61
more_args = "max_num_seqs=64"
58 62

63+
# Add TP test (if provided)
64+
if TPU_TP_TEST_STR:
65+
more_args += ",{}".format(TPU_TP_TEST_STR)
66+
59 67
run_test(more_args)
60 68

61 69

vllm/executor/ray_distributed_executor.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,9 +54,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
54 54
def _init_executor(self) -> None:
55 55
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
56 56
if envs.VLLM_USE_V1:
57-
# v1 always uses the compiled DAG and SPMD worker.
57+
# V1 uses SPMD worker and compiled DAG
58 58
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
59 59
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
60+
61+
# For TPU, avoid compiling NVIDIA's NCCL
62+
if current_platform.is_tpu():
63+
os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
64+
60 65
# If the env var is set, it uses the Ray's compiled DAG API
61 66
# which optimizes the control plane overhead.
62 67
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.

vllm/executor/ray_utils.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,13 @@
6 6
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
7 7

8 8
import msgspec
9+
import torch
9 10

10 11
import vllm.platforms
11 12
from vllm.config import ParallelConfig
12 13
from vllm.executor.msgspec_utils import decode_hook, encode_hook
13 14
from vllm.logger import init_logger
15+
from vllm.platforms import current_platform
14 16
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
15 17
from vllm.utils import get_ip
16 18
from vllm.worker.worker_base import WorkerWrapperBase
@@ -106,10 +108,14 @@ def setup_device_if_necessary(self):
106 108
# on a background thread, so we need to reset torch's current
107 109
# device.
108 110
# We can remove this API after it is fixed in compiled graph.
109-
import torch
110 111
assert self.worker is not None, "Worker is not initialized"
111 112
if not self.compiled_dag_cuda_device_set:
112-
torch.cuda.set_device(self.worker.device)
113+
if current_platform.is_tpu():
114+
# TODO: [AlexM] Verify if set_device is necessary here
115+
pass
116+
else:
117+
torch.cuda.set_device(self.worker.device)
118+
113 119
self.compiled_dag_cuda_device_set = True
114 120

115 121
def execute_model(

vllm/v1/worker/tpu_model_runner.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
21 21
from vllm.logger import init_logger
22 22
from vllm.model_executor.model_loader import get_model
23 23
from vllm.sampling_params import SamplingType
24+
from vllm.sequence import IntermediateTensors
24 25
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
25 26
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
26 27
PallasMetadata)
@@ -583,6 +584,7 @@ def _prepare_decode(
583 584
def execute_model(
584 585
self,
585 586
scheduler_output: "SchedulerOutput",
587+
intermediate_tensors: Optional[IntermediateTensors] = None,
586 588
) -> ModelRunnerOutput:
587 589
# Update cached state
588 590
self._update_states(scheduler_output)

0 commit comments

Comments
 (0)