1 parent 57c25f8 commit c17e677
vllm/model_executor/models/qwen2_5_vl.py
@@ -198,9 +198,8 @@ def forward(self, x: torch.Tensor):
 
 def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
     """All-gather the input tensor in an interleaved fashion across the model parallel group."""
-    import torch.distributed as dist
     gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
-    dist.all_gather(gathered_tensors, local_tensor)
+    parallel_state.get_tp_group().all_gather(gathered_tensors, local_tensor)
 
     gathered_tensors_split = [
         torch.split(tensor, hidden_size // tp_size, -1)
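
For context, here is a minimal single-process-group sketch of what all_gather_interleave computes, using plain torch.distributed in place of vLLM's parallel_state.get_tp_group() wrapper. The continuation past torch.split and the helper name are illustrative (the diff truncates before the function's tail), and it assumes torch.distributed has already been initialized:

import torch
import torch.distributed as dist

def all_gather_interleave_sketch(local_tensor: torch.Tensor,
                                 hidden_size: int,
                                 tp_size: int) -> torch.Tensor:
    # One receive buffer per tensor-parallel rank.
    gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    # Plain torch.distributed stands in for the TP-group wrapper
    # (parallel_state.get_tp_group().all_gather) used by the commit.
    dist.all_gather(gathered_tensors, local_tensor)
    # Split every gathered shard along the last (hidden) dimension ...
    gathered_tensors_split = [
        torch.split(t, hidden_size // tp_size, -1) for t in gathered_tensors
    ]
    # ... then interleave: chunk 0 of every rank, chunk 1 of every rank, etc.,
    # so the hidden dimension comes back out in its original order.
    ordered = [chunk for group in zip(*gathered_tensors_split) for chunk in group]
    return torch.cat(ordered, dim=-1)

The apparent point of the commit is visible in the one-line swap: calling the collective through get_tp_group() scopes the all-gather to the tensor-parallel process group, whereas a bare dist.all_gather with no group argument runs on torch.distributed's default (world) group.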