Commit b1fae21
bf16 support
1 parent ca59597

2 files changed: +27 -7 lines


Diff for: py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py (+4 -5)

@@ -3,6 +3,7 @@
 from typing import Any, Callable, Optional, Union
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt import _enums
@@ -15,11 +16,10 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
+    to_torch,
 )
 from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor
 
-import tensorrt as trt
-
 
 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -125,10 +125,9 @@ def convert_binary_elementwise(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)):
-        rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))
+        rhs_val = to_torch(rhs_val, dtype=lhs_dtype)
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)):
-        lhs_val = np.array([lhs_val], dtype=_enums.dtype._from(rhs_dtype).to(np.dtype))
-
+        lhs_val = to_torch(lhs_val, dtype=rhs_dtype)
     lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)
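Note on the change above: the scalar-promotion path previously built the constant with np.array, but NumPy has no bfloat16 dtype, so _enums.dtype._from(lhs_dtype).to(np.dtype) has no valid target once the TRT tensor operand is bf16; that is presumably why the commit swaps it for to_torch, since torch represents bfloat16 natively. A minimal standalone sketch of the difference (illustration only, not part of the diff; plain numpy/torch, no torch_tensorrt needed):

import numpy as np
import torch

# NumPy cannot name a bfloat16 dtype, so the old promotion path has no
# dtype to target when the lhs/rhs TRT tensor is bf16.
try:
    np.dtype("bfloat16")
except TypeError as err:
    print(f"numpy: {err}")  # data type 'bfloat16' not understood

# torch supports bfloat16 natively, which is what to_torch builds on.
scalar = torch.tensor([2.0], dtype=torch.bfloat16)
print(scalar.dtype)  # torch.bfloat16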

Diff for: tests/py/dynamo/conversion/test_binary_ops_aten.py (+23 -2)

@@ -2,12 +2,11 @@
 
 import torch
 import torch.nn as nn
+from harness import DispatchTestCase
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
 
-from .harness import DispatchTestCase
-
 NEED_TEST_BOTH_CONSTANTS_CASE = True
 
 elementwise_ops = [
@@ -228,6 +227,28 @@ def forward(self, x, y):
         ]
         self.run_test_with_dynamic_shape(Op(), input_specs)
 
+    @parameterized.expand(
+        [
+            (f"bf16_{op[0].__name__}_one_constant", op[0])
+            for op in elementwise_ops
+            if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"]
+        ]
+    )
+    def test_elementwise_ops_bf16(self, _, orig_op):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.constant = torch.randn(1)
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                x = self.orig_op(x, self.constant)
+                return self.orig_op(x, -2)
+
+        m = TestModule(orig_op)
+        inputs = [torch.randn(2, 2, dtype=torch.bfloat16)]
+        self.run_test(m, inputs)
+
 
 if __name__ == "__main__":
     run_tests()
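The new test drives each elementwise op (pow.Tensor_Tensor and fmod.Tensor excluded) with a bf16 input against both a tensor constant and a Python scalar, which is exactly the promotion path patched in base.py. For context, a hedged end-to-end sketch of the same pattern through the public API rather than the test harness; it assumes a CUDA device and a torch_tensorrt build containing this commit, and ScalarAdd is an illustrative module, not from the diff:

import torch
import torch_tensorrt

class ScalarAdd(torch.nn.Module):
    def forward(self, x):
        # Python-scalar operands: exercise the to_torch promotion path.
        return (x + 2.0) * 0.5

model = ScalarAdd().eval().cuda()
inputs = [torch.randn(2, 2, dtype=torch.bfloat16, device="cuda")]

# ir="dynamo" selects the converter path patched in this commit; depending
# on the build, enabled_precisions may also need to include torch.bfloat16.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).dtype)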
