
Commit a2494a7

feat(//py/torch_tensorrt/dynamo): Support for BF16
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 39b6818 commit a2494a7

File tree

core/util/trt_util.cpp
core/util/trt_util.h
py/torch_tensorrt/_enums.py
py/torch_tensorrt/dynamo/_defaults.py
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
py/torch_tensorrt/dynamo/conversion/converter_utils.py
tests/py/dynamo/models/test_dtype_support.py

7 files changed: +110 -11 lines changed

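For orientation, below is a minimal sketch of how a user would opt into the new BF16 kernel precision after this change, assuming a model and inputs already cast to torch.bfloat16 on the GPU (the module and variable names are illustrative, not part of the commit); it mirrors the pattern of the new test in tests/py/dynamo/models/test_dtype_support.py.

import torch
import torch_tensorrt

# Hypothetical BF16 module; any nn.Module cast to bfloat16 follows the same path.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, stride=1, bias=True),
    torch.nn.ReLU(),
).eval().to("cuda").to(torch.bfloat16)
inputs = [torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16)]

exp_program = torch.export.export(model, tuple(inputs))
trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=inputs,
    # torch.bfloat16 is now accepted here; it maps to trt.BuilderFlag.BF16 below
    enabled_precisions={torch.float, torch.bfloat16},
    min_block_size=1,
)
out = trt_model(*inputs)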

core/util/trt_util.cpp (+3 -2)

@@ -295,7 +295,8 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
     {at::kLong, nvinfer1::DataType::kINT64},
     {at::kChar, nvinfer1::DataType::kINT8},
     {at::kByte, nvinfer1::DataType::kINT8},
-    {at::kBool, nvinfer1::DataType::kBOOL}};
+    {at::kBool, nvinfer1::DataType::kBOOL},
+    {at::kBFloat16, nvinfer1::DataType::kBF16}};
   return at_trt_type_map;
 }

@@ -307,7 +308,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_ma
     {nvinfer1::DataType::kINT64, at::kLong},
     {nvinfer1::DataType::kINT8, at::kChar},
     {nvinfer1::DataType::kBOOL, at::kBool},
-  };
+    {nvinfer1::DataType::kBF16, at::kBFloat16}};
   return trt_at_type_map;
 }
 } // namespace

core/util/trt_util.h (+2 -0)

@@ -55,6 +55,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
       return stream << "Int32";
     case nvinfer1::DataType::kINT64:
       return stream << "Int64";
+    case nvinfer1::DataType::kBF16:
+      return stream << "BFloat16";
     case nvinfer1::DataType::kBOOL:
       return stream << "Bool";
     default:

py/torch_tensorrt/_enums.py (+14 -7)

@@ -24,9 +24,9 @@ class dtype(Enum):
     f32 = auto()
     f64 = auto()
     b = auto()
-    # TODO: Enable FP8 and BF16
+    bf16 = auto()
+    # TODO: Enable FP8
     # f8 = auto()
-    # bf16 = auto()

     uint8 = u8
     int8 = i8
@@ -52,8 +52,7 @@ class dtype(Enum):
     # float8 = f8
     # fp8 = f8

-    # TODO: Enable when BF16 is enabled
-    # bfloat16 = bf16
+    bfloat16 = bf16

     @staticmethod
     def _is_np_obj(t: Any) -> bool:
@@ -88,14 +87,16 @@ def _from(
             return dtype.f64
         elif t == torch.bool:
             return dtype.b
+        elif t == torch.bfloat16:
+            return dtype.bf16
         elif use_default:
             logging.warning(
                 f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float"
             )
             return dtype.float
         else:
             raise TypeError(
-                f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {t}"
+                f"Provided an unsupported data type as a data type for translation (support: bool, int, long, half, float, bfloat16), got: {t}"
             )
     elif isinstance(t, trt.DataType):
         if t == trt.uint8:
@@ -112,9 +113,11 @@ def _from(
             return dtype.f32
         elif t == trt.bool:
             return dtype.b
+        elif t == trt.bf16:
+            return dtype.bf16
         else:
             raise TypeError(
-                f"Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: {t}"
+                f"Provided an unsupported data type as a data type for translation (support: bool, int, half, float, bfloat16), got: {t}"
             )

     elif dtype._is_np_obj(t):
@@ -141,7 +144,7 @@ def _from(
             return dtype.float
         else:
             raise TypeError(
-                "Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: "
+                "Provided an unsupported data type as an input data type (support: bool, int, long, half, float, bfloat16), got: "
                 + str(t)
             )

@@ -215,6 +218,8 @@ def to(
             return torch.double
         elif self == dtype.b:
             return torch.bool
+        elif self == dtype.bf16:
+            return torch.bfloat16
         elif use_default:
             logging.warning(
                 f"Given dtype that does not have direct mapping to torch ({self}), defaulting to torch.float"
@@ -238,6 +243,8 @@ def to(
             return trt.DataType.FLOAT
         elif self == dtype.b:
             return trt.DataType.BOOL
+        elif self == dtype.bf16:
+            return trt.DataType.BF16
         elif use_default:
             return trt.DataType.FLOAT
         else:
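As a quick illustration of the new enum plumbing, a sketch based on the _from/to hunks above; the exact shape of the to(...) call (dispatch on the target type class) is assumed from the branches shown, not confirmed by the diff itself.

import tensorrt as trt
import torch
from torch_tensorrt import dtype

d = dtype._from(torch.bfloat16)   # -> dtype.bf16 (alias: dtype.bfloat16)
print(d.to(torch.dtype))          # torch.bfloat16
print(d.to(trt.DataType))         # DataType.BF16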

py/torch_tensorrt/dynamo/_defaults.py (+1 -1)

@@ -26,7 +26,7 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8}
+SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8, dtype.bf16}


 def default_device() -> Device:

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+3 -0)

@@ -242,6 +242,9 @@ def _populate_trt_builder_config(
         if dtype.int8 in self.compilation_settings.enabled_precisions:
             builder_config.set_flag(trt.BuilderFlag.INT8)

+        if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
+            builder_config.set_flag(trt.BuilderFlag.BF16)
+
         if self.compilation_settings.sparse_weights:
             builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
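For reference, the flag the interpreter sets above corresponds to the standalone TensorRT Python API shown in this sketch, independent of the interpreter and assuming a TensorRT release that exposes BuilderFlag.BF16.

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
# Opt the builder into BF16 kernels, mirroring what the interpreter does when
# dtype.bfloat16 appears in enabled_precisions.
config.set_flag(trt.BuilderFlag.BF16)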

py/torch_tensorrt/dynamo/conversion/converter_utils.py (+5 -1)

@@ -4,7 +4,6 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

 import numpy as np
-import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch import SymBool, SymFloat, SymInt
@@ -22,6 +21,8 @@
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTTensor

+import tensorrt as trt
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -545,6 +546,9 @@ def to_numpy(
     elif isinstance(value, torch.Tensor):
         if value.is_quantized:
             value = value.dequantize()
+        elif value.dtype == torch.bfloat16:
+            # TODO: Remove when numpy has a BF16 type
+            value = value.to(torch.float)

         output = value.cpu().detach().contiguous().numpy()
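The to_numpy change works around NumPy's lack of a native bfloat16 dtype; a minimal sketch of the same upcast outside the converter:

import torch

t = torch.randn(4, dtype=torch.bfloat16)
# t.numpy() would fail because NumPy has no bfloat16 dtype,
# so upcast to float32 first, as to_numpy now does.
arr = t.to(torch.float).cpu().detach().contiguous().numpy()
print(arr.dtype)  # float32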

tests/py/dynamo/models/test_dtype_support.py (+82 -0)

@@ -176,3 +176,85 @@ def forward(self, x):
             DECIMALS_OF_AGREEMENT,
             msg=f"Torch outputs and TRT outputs don't match close enough.",
         )
+
+
+class TestBF16Support(TestCase):
+    @unittest.skipIf(
+        not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime,
+        "Torch-TensorRT Runtime is not available",
+    )
+    def test_bf16_cpp(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                out = self.conv(x)
+                out = self.relu(out)
+                return out
+
+        in_tensor = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16)
+        mod = MyModule().to(torch.device("cuda")).to(torch.bfloat16)
+
+        exp_mod = torch.export.export(mod, (in_tensor,))
+        trt_mod = torch_tensorrt.dynamo.compile(
+            exp_mod,
+            inputs=[in_tensor],
+            pass_through_build_failures=True,
+            enabled_precisions={torch.float, torch.bfloat16, torch.half},
+            min_block_size=1,
+            use_python_runtime=False,
+        )
+
+        torch_model_results = mod(in_tensor)
+        optimized_model_results = trt_mod(in_tensor)
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"Torch outputs and TRT outputs don't match close enough.",
+        )
+
+    def test_bf16_py(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                out = self.conv(x)
+                out = self.relu(out)
+                return out
+
+        in_tensor = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16)
+        mod = MyModule().to(torch.device("cuda")).to(torch.bfloat16)
+
+        exp_mod = torch.export.export(mod, (in_tensor,))
+        trt_mod = torch_tensorrt.dynamo.compile(
+            exp_mod,
+            inputs=[in_tensor],
+            pass_through_build_failures=True,
+            enabled_precisions={torch.float, torch.bfloat16, torch.half},
+            min_block_size=1,
+            use_python_runtime=True,
+        )
+
+        torch_model_results = mod(in_tensor)
+        optimized_model_results = trt_mod(in_tensor)
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"Torch outputs and TRT outputs don't match close enough.",
+        )
