@@ -176,3 +176,85 @@ def forward(self, x):
176
176
DECIMALS_OF_AGREEMENT ,
177
177
msg = f"Torch outputs and TRT outputs don't match close enough." ,
178
178
)
179
+
180
+
181
+ class TestBF16Support (TestCase ):
182
+ @unittest .skipIf (
183
+ not torch_tensorrt .ENABLED_FEATURES .torch_tensorrt_runtime ,
184
+ "Torch-TensorRT Runtime is not available" ,
185
+ )
186
+ def test_bf16_cpp (self ):
187
+ class MyModule (torch .nn .Module ):
188
+ def __init__ (self ):
189
+ super ().__init__ ()
190
+ self .conv = torch .nn .Conv2d (3 , 16 , 3 , stride = 1 , bias = True )
191
+ self .relu = torch .nn .ReLU ()
192
+
193
+ def forward (self , x ):
194
+ out = self .conv (x )
195
+ out = self .relu (out )
196
+ return out
197
+
198
+ in_tensor = torch .randn ((1 , 3 , 224 , 224 ), device = "cuda" , dtype = torch .bfloat16 )
199
+ mod = MyModule ().to (torch .device ("cuda" )).to (torch .bfloat16 )
200
+
201
+ exp_mod = torch .export .export (mod , (in_tensor ,))
202
+ trt_mod = torch_tensorrt .dynamo .compile (
203
+ exp_mod ,
204
+ inputs = [in_tensor ],
205
+ pass_through_build_failures = True ,
206
+ enabled_precisions = {torch .float , torch .bfloat16 , torch .half },
207
+ min_block_size = 1 ,
208
+ use_python_runtime = False ,
209
+ )
210
+
211
+ torch_model_results = mod (in_tensor )
212
+ optimized_model_results = trt_mod (in_tensor )
213
+
214
+ max_diff = float (
215
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
216
+ )
217
+ self .assertAlmostEqual (
218
+ max_diff ,
219
+ 0 ,
220
+ DECIMALS_OF_AGREEMENT ,
221
+ msg = f"Torch outputs and TRT outputs don't match close enough." ,
222
+ )
223
+
224
+ def test_bf16_py (self ):
225
+ class MyModule (torch .nn .Module ):
226
+ def __init__ (self ):
227
+ super ().__init__ ()
228
+ self .conv = torch .nn .Conv2d (3 , 16 , 3 , stride = 1 , bias = True )
229
+ self .relu = torch .nn .ReLU ()
230
+
231
+ def forward (self , x ):
232
+ out = self .conv (x )
233
+ out = self .relu (out )
234
+ return out
235
+
236
+ in_tensor = torch .randn ((1 , 3 , 224 , 224 ), device = "cuda" , dtype = torch .bfloat16 )
237
+ mod = MyModule ().to (torch .device ("cuda" )).to (torch .bfloat16 )
238
+
239
+ exp_mod = torch .export .export (mod , (in_tensor ,))
240
+ trt_mod = torch_tensorrt .dynamo .compile (
241
+ exp_mod ,
242
+ inputs = [in_tensor ],
243
+ pass_through_build_failures = True ,
244
+ enabled_precisions = {torch .float , torch .bfloat16 , torch .half },
245
+ min_block_size = 1 ,
246
+ use_python_runtime = True ,
247
+ )
248
+
249
+ torch_model_results = mod (in_tensor )
250
+ optimized_model_results = trt_mod (in_tensor )
251
+
252
+ max_diff = float (
253
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
254
+ )
255
+ self .assertAlmostEqual (
256
+ max_diff ,
257
+ 0 ,
258
+ DECIMALS_OF_AGREEMENT ,
259
+ msg = f"Torch outputs and TRT outputs don't match close enough." ,
260
+ )
0 commit comments