
Commit c916bf6

using nccl ops from TRT-LLM namespace

1 parent: 1820713

File tree: 5 files changed, +265 -8 lines


examples/distributed_inference/README.md (+4)

@@ -14,3 +14,7 @@ See the examples started with `data_parallel` for more details.
 Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.
 
 torchrun --nproc_per_node=2 tensor_parallel_llama2.py
+
+3. Tensor parallel distributed inference using nccl ops plugin
+
+mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
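The README entry above assumes a "properly sharded" module. As a point of reference, here is a minimal sketch (not part of this commit) of sharding a toy two-layer model with PyTorch tensor parallelism before compilation; the model and the plan keys "in_proj"/"out_proj" are illustrative, and the script is meant to be launched with torchrun or mpirun on 2 GPUs:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)  # sharded column-wise across ranks
        self.out_proj = nn.Linear(32, 5)  # sharded row-wise across ranks

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))


mesh = init_device_mesh(device_type="cuda", mesh_shape=(2,))  # 2 ranks
model = ToyModel().cuda()
tp_model = parallelize_module(
    model, mesh, {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()}
)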
examples/distributed_inference/requirements.txt (+4 -1)

@@ -1,3 +1,6 @@
 accelerate
 transformers
-diffusers
+diffusers
+site
+# Install tensorrt-llm without its dependencies (run the command separately): pip install tensorrt-llm --no-deps
+tensorrt-llm
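The comment above installs tensorrt-llm mainly for the plugin library it bundles. A quick sanity check (not part of the commit) that the shared object the example later loads is actually present; the libs/ layout mirrors the plugin_lib_path used in tensor_parallel_simple_example.py below:

import os

import tensorrt_llm

# tensorrt_llm.__file__ points at the package's __init__.py; the NCCL plugin
# shared library ships next to it under libs/.
plugin_lib = os.path.join(
    os.path.dirname(tensorrt_llm.__file__), "libs", "libnvinfer_plugin_tensorrt_llm.so"
)
print(plugin_lib, os.path.exists(plugin_lib))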

examples/distributed_inference/tensor_parallel_simple_example.py (+184 -7)

@@ -1,8 +1,17 @@
+import ctypes
+import logging
 import os
+import site
 import sys
 import time
+from enum import IntEnum, IntFlag, auto
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
+import numpy as np
+import tensorrt as trt
+import tensorrt_llm
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch_tensorrt
 from torch.distributed._tensor import Shard
@@ -12,6 +21,181 @@
     RowwiseParallel,
     parallelize_module,
 )
+from torch.fx import GraphModule, Node
+from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    dynamo_tensorrt_converter,
+)
+from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+    custom_fused_all_gather_op,
+    custom_fused_reduce_scatter_op,
+)
+from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+
+
+# This is required for env initialization since we use mpirun
+def initialize(rank=0, world_size=1, port=29500):
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variables to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use the nccl backend
+    dist.init_process_group("nccl")
+
+    # Set a manual seed for reproducibility
+    torch.manual_seed(1111)
+
+    return local_rank, world_size
+
+
+initialize()
+# Create a device mesh based on the given world_size.
+_world_size = int(os.environ["WORLD_SIZE"])
+
+device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+_rank = device_mesh.get_rank()
+device_id = _rank % torch.cuda.device_count()  # Ensure each rank gets a unique device
+torch.cuda.set_device(device_id)
+
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+fh = logging.FileHandler(f"./tensor_parallel_simple_example_{_rank}.log", mode="w")
+fh.setLevel(logging.INFO)
+logger.addHandler(fh)
+
+
+# TensorRT-LLM NCCL plugins
+tensorrt_llm_lib_dir = os.path.dirname(tensorrt_llm.__file__)  # __file__ is .../tensorrt_llm/__init__.py
+plugin_lib_path = tensorrt_llm_lib_dir + "/libs/libnvinfer_plugin_tensorrt_llm.so"
+try:
+    ctypes.CDLL(plugin_lib_path)
+    logger.info("plugin loaded successfully")
+except OSError as e:
+    logger.info(f"unsuccessful load : {e}")
+trt.init_libnvinfer_plugins(None, "")
+# Iterate over all registered plugin creators
+plugin_registry = trt.get_plugin_registry()
+for plugin_creator in plugin_registry.plugin_creator_list:
+    logger.info(
+        f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
+    )
+
+
+# Configuration enums mirroring the TensorRT-LLM AllReduce plugin
+class AllReduceStrategy(IntEnum):
+    """Warning: actual definition is in kernels/customAllReduceKernels.h.
+
+    They must be kept in sync.
+    """
+
+    NCCL = 0
+    ONESHOT = 1
+    TWOSHOT = 2
+    AUTO = 3
+
+
+class AllReduceConfig(IntFlag):
+    """Warning: actual definition is in kernels/customAllReduceKernels.h.
+
+    They must be kept in sync.
+    """
+
+    USE_MEMCPY = auto()
+    PUSH_MODE = auto()
+
+
+@dynamo_tensorrt_converter(custom_fused_all_gather_op)
+def insert_nccl_gather_op(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    plug_inputs = [args[0]]
+    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
+        "AllGather", "1", "tensorrt_llm"
+    )
+    assert allgather_plg_creator is not None
+    world_size = dist.get_world_size()
+    group = list(range(world_size))
+    group = trt.PluginField(
+        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+    p_dtype = trt.float16
+    pf_type = trt.PluginField(
+        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
+    )
+    pfc = trt.PluginFieldCollection([group, pf_type])
+    allgather = allgather_plg_creator.create_plugin("allgather", pfc)
+    layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+@dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
+def insert_nccl_reduce_scatter_plugin(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    plug_inputs = [args[0]]
+    allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
+        "ReduceScatter", "1", "tensorrt_llm"
+    )
+
+    assert allreduce_plg_creator is not None
+
+    counter = 0
+    strategy = AllReduceStrategy.NCCL
+    config = AllReduceConfig(0)
+
+    world_size = dist.get_world_size()
+    group = list(range(world_size))
+    group = trt.PluginField(
+        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+
+    p_dtype = trt.float16
+    pf_dtype = trt.PluginField(
+        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
+    )
+    pfc = [group, pf_dtype]
+    p_strategy = trt.PluginField(
+        "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8
+    )
+    pfc.append(p_strategy)
+    p_config = trt.PluginField(
+        "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8
+    )
+    pfc.append(p_config)
+    p_counter = trt.PluginField(
+        "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32
+    )
+    pfc.append(p_counter)
+
+    pfc = trt.PluginFieldCollection(pfc)
+    ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc)
+
+    layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
 
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py

@@ -36,13 +220,6 @@ def forward(self, x):
         return x
 
 
-# create a device mesh based on the given world_size.
-_world_size = int(os.environ["WORLD_SIZE"])
-
-device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
-_rank = device_mesh.get_rank()
-
-
 print(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
     _world_size % 2 == 0
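The hunks above end before the part of the example that shards and compiles the model. As a rough sketch of how the pieces connect (tp_model, the input shape, and the compile options are assumptions, not shown in this diff): once the model has been sharded with parallelize_module over device_mesh, compiling it with the torch_tensorrt backend runs the fuse_distributed_ops lowering pass and then the converters registered above.

# Sketch only: tp_model is the sharded module, not shown in this diff.
inp = torch.rand(20, 10, device="cuda")
compiled = torch.compile(tp_model, backend="torch_tensorrt", dynamic=False)
out = compiled(inp)  # collectives execute through the TRT-LLM NCCL plugin layers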

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py (+2)

@@ -5,6 +5,7 @@
 
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
+from .fuse_distributed_ops import fuse_distributed_ops
 from .fuse_prims_broadcast import fuse_prims_broadcast
 from .lower_linear import lower_linear
 from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
@@ -25,6 +26,7 @@
         lower_scaled_dot_product_attention,
         lower_linear,
         fuse_prims_broadcast,
+        fuse_distributed_ops,
         replace_max_pool_with_indices,
         replace_full_like_with_full,
         view_to_reshape,
py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py (new file, +71)

@@ -0,0 +1,71 @@
+import logging
+from typing import Sequence
+
+import torch
+
+# dead-code elimination, linting, and recompilation of the graph, in place
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def custom_fused_all_gather_op(args0, args1, args2):
+    return torch.ops._c10d_functional.wait_tensor.default(
+        torch.ops._c10d_functional.all_gather_into_tensor.default(args0, args1, args2)
+    )
+
+
+def custom_fused_reduce_scatter_op(args0, args1, args2, args3):
+    return torch.ops._c10d_functional.wait_tensor.default(
+        torch.ops._c10d_functional.reduce_scatter_tensor.default(
+            args0, args1, args2, args3
+        )
+    )
+
+
+def fuse_distributed_ops(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    modified_graph = False
+    for node in gm.graph.nodes:
+        if (
+            node.target
+            in (
+                torch.ops._c10d_functional.all_gather_into_tensor.default,
+                torch.ops._c10d_functional.reduce_scatter_tensor.default,
+            )
+            and len(node.users) == 1
+            and list(node.users)[0].target
+            == torch.ops._c10d_functional.wait_tensor.default
+        ):
+            wait_tensor_node = list(node.users)[0]
+            fused_op = None
+            if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
+                fused_op = custom_fused_all_gather_op
+                fused_op_args = (node.args[0], node.args[1], node.args[2])
+            else:
+                fused_op = custom_fused_reduce_scatter_op
+                fused_op_args = (node.args[0], node.args[1], node.args[2], node.args[3])
+            with gm.graph.inserting_after(wait_tensor_node):
+                fused_node = gm.graph.create_node(
+                    op="call_function",
+                    target=fused_op,  # the fused collective helper defined above
+                    args=fused_op_args,
+                )
+
+            wait_tensor_node.replace_all_uses_with(fused_node)
+            fused_node.meta.update(node.meta)
+            modified_graph = True
+            gm.graph.erase_node(wait_tensor_node)
+            gm.graph.erase_node(node)
+
+    # If the graph was modified, clean it up
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(
+            f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}"
+        )
+
+    return gm
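To see what fuse_distributed_ops does to a graph, here is a small self-contained sketch (not part of the commit) that builds an FX graph containing the functional all_gather followed by wait_tensor and runs the pass over it; the group-size/group-name arguments are placeholder values, since the ops are only recorded, not executed:

import torch
from torch.fx import Graph, GraphModule

from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    fuse_distributed_ops,
)

# Mimic what dynamo traces for a sharded module: all_gather + wait_tensor.
graph = Graph()
x = graph.placeholder("x")
gathered = graph.call_function(
    torch.ops._c10d_functional.all_gather_into_tensor.default,
    args=(x, 2, "0"),  # (input, group_size, group_name): placeholder values
)
waited = graph.call_function(
    torch.ops._c10d_functional.wait_tensor.default, args=(gathered,)
)
graph.output(waited)
gm = GraphModule(torch.nn.Module(), graph)

gm = fuse_distributed_ops(gm, [])
# The pair is now a single call to custom_fused_all_gather_op, which the
# converter in the example maps onto the TRT-LLM "AllGather" plugin.
print(gm.graph)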
