
Commit 28ad766

[V1] TPU - Add tensor parallel support via Ray
Signed-off-by: Alexander Matveev <[email protected]>
1 parent: 4167252

File tree

4 files changed: +24 -3 lines changed

tests/entrypoints/llm/test_accuracy.py

Lines changed: 8 additions & 0 deletions

@@ -42,6 +42,10 @@ def run_test(more_args=None):
     ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


+# TODO: [AlexM] Fix it with new CI/CD tests
+TPU_TP_TEST_STR = ""  # "tensor_parallel_size=4"
+
+
 @pytest.mark.skipif(not current_platform.is_cuda()
                     and not current_platform.is_tpu(),
                     reason="V1 is currently only supported on CUDA and TPU")
@@ -56,6 +60,10 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch):
         # Limit compilation time for TPU V1
         more_args = "max_num_seqs=64"

+        # Add TP test (if provided)
+        if TPU_TP_TEST_STR:
+            more_args += ",{}".format(TPU_TP_TEST_STR)
+
     run_test(more_args)
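
The test builds engine options as a comma-separated key=value string. For context, a minimal sketch of how such a string could be split back into keyword arguments (parse_more_args is a hypothetical helper for illustration, not something the test defines):

# Hypothetical helper (not part of the commit): parse a comma-separated
# "key=value" string, like the one the test builds, into a kwargs dict,
# casting integer-looking values.
def parse_more_args(more_args: str) -> dict:
    kwargs = {}
    for pair in more_args.split(","):
        key, _, value = pair.partition("=")
        kwargs[key] = int(value) if value.isdigit() else value
    return kwargs

# Example: the string the test would build with the TP option enabled.
assert parse_more_args("max_num_seqs=64,tensor_parallel_size=4") == {
    "max_num_seqs": 64,
    "tensor_parallel_size": 4,
}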

vllm/executor/ray_distributed_executor.py

Lines changed: 6 additions & 1 deletion

@@ -54,9 +54,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
         if envs.VLLM_USE_V1:
-            # v1 always uses the compiled DAG and SPMD worker.
+            # V1 uses SPMD worker and compiled DAG
             os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
             os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
+
+            # For TPU, avoid compiling NVIDIA's NCCL
+            if current_platform.is_tpu():
+                os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
+
         # If the env var is set, it uses the Ray's compiled DAG API
         # which optimizes the control plane overhead.
         # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
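
Putting the executor change in context: a tensor-parallel V1 run on TPU would be launched through Ray roughly as below. This is a minimal usage sketch, assuming a TPU host with four chips; the model name is a placeholder, and the comments restate what this commit's code does rather than adding guarantees:

import os
os.environ["VLLM_USE_V1"] = "1"  # select the V1 engine

from vllm import LLM, SamplingParams

# With tensor_parallel_size > 1 and the Ray backend, vLLM routes execution
# through RayDistributedExecutor, which (per this commit) enables the SPMD
# worker and compiled DAG, and on TPU disables the NCCL channel via
# VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL=0.
llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",  # placeholder model
    tensor_parallel_size=4,
    distributed_executor_backend="ray",
    max_num_seqs=64,  # limits TPU compilation time, as in the test above
)
outputs = llm.generate(["Hello, TPU!"], SamplingParams(max_tokens=16))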

vllm/executor/ray_utils.py

Lines changed: 8 additions & 2 deletions

@@ -6,11 +6,13 @@
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import msgspec
+import torch

 import vllm.platforms
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -106,10 +108,14 @@ def setup_device_if_necessary(self):
         # on a background thread, so we need to reset torch's current
         # device.
         # We can remove this API after it is fixed in compiled graph.
-        import torch
         assert self.worker is not None, "Worker is not initialized"
         if not self.compiled_dag_cuda_device_set:
-            torch.cuda.set_device(self.worker.device)
+            if current_platform.is_tpu():
+                # TODO: [AlexM] Verify if set_device is necessary here
+                pass
+            else:
+                torch.cuda.set_device(self.worker.device)
+
             self.compiled_dag_cuda_device_set = True

     def execute_model_ray(
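
The TODO in setup_device_if_necessary leaves the TPU branch as a deliberate no-op. If an explicit device pin ever proves necessary on TPU, it would presumably go through torch_xla rather than torch.cuda; a sketch of the dispatch shape, with the torch_xla path shown only as a commented-out assumption:

import torch
from vllm.platforms import current_platform

def set_worker_device(device: torch.device) -> None:
    """Sketch of the platform dispatch used in setup_device_if_necessary."""
    if current_platform.is_tpu():
        # No-op in this commit (see the TODO above). If a pin were needed,
        # it might go through torch_xla (unverified assumption):
        # import torch_xla.core.xla_model as xm
        # device = xm.xla_device()
        pass
    else:
        torch.cuda.set_device(device)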

vllm/v1/worker/tpu_model_runner.py

Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
@@ -430,6 +431,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> ModelRunnerOutput:
         # Update cached state
         self._update_states(scheduler_output)
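
The new intermediate_tensors parameter brings the TPU runner's execute_model signature in line with the GPU model runners, where a non-None value carries hidden states between pipeline-parallel stages; on a single-stage TPU setup it simply stays None. A minimal sketch of that calling convention (run_stage is illustrative, not vLLM API):

from typing import Optional

from vllm.sequence import IntermediateTensors

def run_stage(runner, scheduler_output,
              prev: Optional[IntermediateTensors] = None):
    # Illustrative only: a later pipeline stage would receive the previous
    # stage's hidden states here. Without pipeline parallelism, prev is
    # None, which is why the TPU runner can accept the argument and
    # ignore it for now.
    return runner.execute_model(scheduler_output,
                                intermediate_tensors=prev)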
