+ import ctypes
+ import logging
import os
+ import site
import sys
import time
+ from enum import IntEnum, IntFlag, auto
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

+ import numpy as np
+ import tensorrt as trt
+ import tensorrt_llm
import torch
+ import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from torch.distributed._tensor import Shard
    RowwiseParallel,
    parallelize_module,
)
+ from torch.fx import GraphModule, Node
+ from torch.fx.node import Argument, Target
+ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+ from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+     dynamo_tensorrt_converter,
+ )
+ from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+     custom_fused_all_gather_op,
+     custom_fused_reduce_scatter_op,
+ )
+ from torch_tensorrt.dynamo.types import TRTTensor
+ from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+
+
+ # This is required for env initialization since we use mpirun
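+ # For example (assuming this file is saved as tensor_parallel_simple_example.py):
+ #   mpirun -n 2 python tensor_parallel_simple_example.py
+ # mpirun sets OMPI_COMM_WORLD_LOCAL_RANK / OMPI_COMM_WORLD_SIZE for each rank;
+ # the defaults below are only used when those variables are absent.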
+ def initialize(rank=0, world_size=1, port=29500):
+     local_rank = int(
+         os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+     )
+     world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+     # Set up environment variables to run with mpirun
+     os.environ["RANK"] = str(local_rank)
+     os.environ["WORLD_SIZE"] = str(world_size)
+     os.environ["MASTER_ADDR"] = "127.0.0.1"
+     os.environ["MASTER_PORT"] = str(port)
+
+     # Necessary to assign a device to each rank.
+     torch.cuda.set_device(local_rank)
+
+     # We use the nccl backend
+     dist.init_process_group("nccl")
+
+     # set a manual seed for reproducibility
+     torch.manual_seed(1111)
+
+     return local_rank, world_size
+
+
+ initialize()
+ # create a device mesh based on the given world_size.
+ _world_size = int(os.environ["WORLD_SIZE"])
+
+ device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+ _rank = device_mesh.get_rank()
+ device_id = _rank % torch.cuda.device_count()  # Ensure each rank gets a unique device
+ torch.cuda.set_device(device_id)
+
+
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+ fh = logging.FileHandler(f"./tensor_parallel_simple_example_{_rank}.log", mode="w")
+ fh.setLevel(logging.INFO)
+ logger.addHandler(fh)
+
+
+ # TensorRT NCCL plugins
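+ # Loading libnvinfer_plugin_tensorrt_llm.so is expected to register the TensorRT-LLM
+ # communication plugins (e.g. "AllGather" and "ReduceScatter" in the "tensorrt_llm"
+ # namespace) with the TensorRT plugin registry, which the converters below rely on.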
+ # tensorrt_llm.__file__ points at the package's __init__.py, so take its directory
+ tensorrt_llm_lib_path = os.path.dirname(tensorrt_llm.__file__)
+ plugin_lib_path = tensorrt_llm_lib_path + "/libs/libnvinfer_plugin_tensorrt_llm.so"
+ try:
+     ctypes.CDLL(plugin_lib_path)
+     logger.info("plugin loaded successfully")
+ except OSError as e:
+     logger.info(f"unsuccessful load: {e}")
+ trt.init_libnvinfer_plugins(None, "")
+ # Iterate over all registered plugin creators
+ plugin_registry = trt.get_plugin_registry()
+ for plugin_creator in plugin_registry.plugin_creator_list:
+     logger.info(
+         f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
+     )
+
+
+ # Enums mirroring TensorRT-LLM's AllReduce kernel configuration
+ class AllReduceStrategy(IntEnum):
+     """Warning: actual definition is in kernels/customAllReduceKernels.h.
+
+     They must be kept in sync.
+     """
+
+     NCCL = 0
+     ONESHOT = 1
+     TWOSHOT = 2
+     AUTO = 3
+
+
+ class AllReduceConfig(IntFlag):
+     """Warning: actual definition is in kernels/customAllReduceKernels.h.
+
+     They must be kept in sync.
+     """
+
+     USE_MEMCPY = auto()
+     PUSH_MODE = auto()
+
+
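+ # The @dynamo_tensorrt_converter registrations below map the fused collective ops
+ # produced by the fuse_distributed_ops lowering pass (custom_fused_all_gather_op and
+ # custom_fused_reduce_scatter_op) onto TensorRT-LLM NCCL plugin layers, so these
+ # collectives can be executed inside the TensorRT engine.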
+ @dynamo_tensorrt_converter(custom_fused_all_gather_op)
+ def insert_nccl_gather_op(
+     ctx: ConversionContext,
+     target: Target,
+     args: Tuple[Argument, ...],
+     kwargs: Dict[str, Argument],
+     name: str,
+ ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+     plug_inputs = [args[0]]
+     allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
+         "AllGather", "1", "tensorrt_llm"
+     )
+     assert allgather_plg_creator is not None
+     world_size = dist.get_world_size()
+     group = list(range(world_size))
+     group = trt.PluginField(
+         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+     )
+     p_dtype = trt.float16
+     pf_type = trt.PluginField(
+         "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
+     )
+     pfc = trt.PluginFieldCollection([group, pf_type])
+     allgather = allgather_plg_creator.create_plugin("allgather", pfc)
+     layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
+     set_layer_name(layer, target, name)
+     return layer.get_output(0)
+
+
+ @dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
+ def insert_nccl_reduce_scatter_plugin(
+     ctx: ConversionContext,
+     target: Target,
+     args: Tuple[Argument, ...],
+     kwargs: Dict[str, Argument],
+     name: str,
+ ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+     plug_inputs = [args[0]]
+     allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
+         "ReduceScatter", "1", "tensorrt_llm"
+     )
+
+     assert allreduce_plg_creator is not None
+
+     counter = 0
+     strategy = AllReduceStrategy.NCCL
+     config = AllReduceConfig(0)
+
+     world_size = dist.get_world_size()
+     group = list(range(world_size))
+     group = trt.PluginField(
+         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+     )
+
+     p_dtype = trt.float16
+     pf_dtype = trt.PluginField(
+         "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
+     )
+     pfc = [group, pf_dtype]
+     p_strategy = trt.PluginField(
+         "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8
+     )
+     pfc.append(p_strategy)
+     p_config = trt.PluginField(
+         "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8
+     )
+     pfc.append(p_config)
+     p_counter = trt.PluginField(
+         "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32
+     )
+     pfc.append(p_counter)
+
+     pfc = trt.PluginFieldCollection(pfc)
+     ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc)
+
+     layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug)
+     set_layer_name(layer, target, name)
+     return layer.get_output(0)
+
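+ # Note: both converters set "type_id" to trt.float16, so the NCCL plugin layers are
+ # presumably configured for fp16 tensors; other dtypes would need a different p_dtype.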

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -36,13 +220,6 @@ def forward(self, x):
        return x


- # create a device mesh based on the given world_size.
- _world_size = int(os.environ["WORLD_SIZE"])
-
- device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
- _rank = device_mesh.get_rank()
-
-
print(f"Starting PyTorch TP example on rank {_rank}.")
assert (
    _world_size % 2 == 0