@@ -10,6 +10,7 @@
     ncclRedOpTypeEnum, ncclUniqueId)
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
+from vllm.utils import current_stream
 
 logger = init_logger(__name__)
 
@@ -96,7 +97,7 @@ def __init__(
         self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
             self.world_size, self.unique_id, self.rank)
 
-        stream = torch.cuda.current_stream()
+        stream = current_stream()
         # A small all_reduce for warmup.
         data = torch.zeros(1, device=device)
         self.all_reduce(data)
@@ -119,7 +120,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -141,7 +142,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -162,7 +163,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -177,7 +178,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -189,7 +190,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -201,7 +202,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer