import torch
from torch.distributed import ProcessGroup

+ from vllm.config import get_current_vllm_config
+ from vllm.logger import init_logger
from vllm.platforms import current_platform

from .base_device_communicator import DeviceCommunicatorBase

+ USE_RAY = get_current_vllm_config(
+ ).parallel_config.distributed_executor_backend == "ray"
+
+ logger = init_logger(__name__)
+
if current_platform.is_tpu():
+     import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt

-     from vllm.executor import ray_utils
+     if USE_RAY:
+         from vllm.executor import ray_utils


class TpuCommunicator(DeviceCommunicatorBase):
@@ -33,19 +42,32 @@ def __init__(self,
        global_rank = self.global_rank
        global_world_size = self.global_world_size

-         # Calculate how many TPU nodes are in the current deployment. This
-         # is the Ray placement group if it is deployed with Ray. Default
-         # to the number of TPU nodes in the Ray cluster. The number of TPU
-         # nodes is computed by the total number of TPUs divided by the
-         # number of TPU accelerators per node, to account for clusters
-         # with both CPUs and TPUs.
-         num_nodes = ray_utils.get_num_tpu_nodes()
-         num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
-         if num_nodes_in_pg > 0:
-             num_nodes = num_nodes_in_pg
-
-         local_world_size = global_world_size // num_nodes
-         local_rank = global_rank % local_world_size
+         if USE_RAY:
+             logger.info("TpuCommunicator initialized with RAY")
+             # Calculate how many TPU nodes are in the current deployment. This
+             # is the Ray placement group if it is deployed with Ray. Default
+             # to the number of TPU nodes in the Ray cluster. The number of TPU
+             # nodes is computed by the total number of TPUs divided by the
+             # number of TPU accelerators per node, to account for clusters
+             # with both CPUs and TPUs.
+             num_nodes = ray_utils.get_num_tpu_nodes()
+             num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
+             if num_nodes_in_pg > 0:
+                 num_nodes = num_nodes_in_pg
+
+             local_world_size = global_world_size // num_nodes
+             local_rank = global_rank % local_world_size
+         else:
+             logger.info("TpuCommunicator initialized with MP")
+             # Sanity check: verify that we are running on a single host.
+             num_hosts = torch_xla.tpu.num_tpu_workers()
+             assert num_hosts == 1
+
+             # Number of TPU chips available on this host.
+             local_world_size = torch_xla.tpu.num_available_chips()
+
+             # Local rank within this host.
+             local_rank = global_rank % local_world_size

        # Ensure environment variables are set for multihost deployments.
        # On GKE, this is needed for libtpu and TPU driver to know which TPU
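For reference, a minimal standalone sketch of how the two branches above derive the local topology; the `local_topology` helper, its parameters, and the example values are hypothetical and not part of the change:

# Sketch only: mirrors the Ray vs. multiprocessing (MP) branches in the diff.
def local_topology(global_rank: int, global_world_size: int,
                   use_ray: bool, num_nodes: int = 1,
                   local_chips: int = 4) -> tuple[int, int]:
    if use_ray:
        # Ray path: split the global world evenly across TPU nodes.
        local_world_size = global_world_size // num_nodes
    else:
        # MP path: single host, so the local world is the number of chips
        # visible on this host.
        local_world_size = local_chips
    local_rank = global_rank % local_world_size
    return local_rank, local_world_size

# Two Ray nodes, 8 ranks total: rank 5 becomes local rank 1 of 4 on its node.
assert local_topology(5, 8, use_ray=True, num_nodes=2) == (1, 4)
# Single-host MP with 4 chips: rank 3 maps to local rank 3 of 4.
assert local_topology(3, 4, use_ray=False) == (3, 4)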