
Commit 0861418

fix: Refactor data type handling in FX
Parent: e1555bc

File tree: 8 files changed, +174 −112 lines

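
This commit collapses the pair of one-way helpers torch_dtype_to_trt and torch_dtype_from_trt into a single unified_dtype_converter(dtype, to) entry point, with the target framework selected by a Frameworks enum. Below is a minimal sketch of that idea, not the in-tree implementation (which lives in torch_tensorrt/fx/utils.py, outside this diff, and presumably also covers numpy); the equivalence table is illustrative and handles only three dtypes.

from enum import Enum

import tensorrt as trt
import torch


class Frameworks(Enum):
    TORCH = "torch"
    TRT = "trt"


# One equivalence table keyed by the TRT dtype, replacing two
# hard-coded one-way mappings.
_TORCH_TO_TRT = {
    torch.int32: trt.int32,
    torch.float16: trt.float16,
    torch.float32: trt.float32,
}
_EQUIVALENCE = {
    trt_dt: {Frameworks.TORCH: torch_dt, Frameworks.TRT: trt_dt}
    for torch_dt, trt_dt in _TORCH_TO_TRT.items()
}


def unified_dtype_converter(dtype, to):
    # Normalize a torch dtype to its TRT key, then look up the target framework.
    key = _TORCH_TO_TRT.get(dtype, dtype)
    return _EQUIVALENCE[key][to]


assert unified_dtype_converter(torch.float16, Frameworks.TRT) == trt.float16
assert unified_dtype_converter(trt.float32, Frameworks.TORCH) == torch.float32

One entry point means every converter states its target framework explicitly at the call site, which is what the hunks below do.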

py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py (+3)
@@ -54,6 +54,7 @@ def forward(self, x):
             0,
             msg=f"MulInt TRT outputs don't match with the original model.",
         )
+        torch._dynamo.reset()

     def test_lowering_add_float(self):
         class AddFloat(torch.nn.Module):
@@ -106,6 +107,8 @@ def forward(self, x):
             msg=f"AddFloat TRT outputs don't match with the original model.",
         )

+        torch._dynamo.reset()
+

 if __name__ == "__main__":
     run_tests()
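
The functional change in this test file is the torch._dynamo.reset() call after each assertion: it clears Dynamo's compilation caches and guards so that one test's compiled artifacts cannot leak into the next. A sketch of the pattern (the backend name here is hypothetical):

import torch

def run_isolated(model, *inputs):
    # Hypothetical backend name, for illustration only.
    compiled = torch.compile(model, backend="torch_tensorrt")
    out = compiled(*inputs)
    # Drop cached graphs and guards so the next test compiles from scratch.
    torch._dynamo.reset()
    return out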

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py (+9, −2)
@@ -16,7 +16,12 @@
 from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS
 from .input_tensor_spec import InputTensorSpec
 from torch_tensorrt.fx.observer import Observer
-from torch_tensorrt.fx.utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
+from torch_tensorrt.fx.utils import (
+    get_dynamic_dims,
+    LowerPrecision,
+    unified_dtype_converter,
+    Frameworks,
+)

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -305,7 +310,9 @@ def placeholder(self, target, args, kwargs):
                 self.optimization_profiles[i].set_shape(target, *shape_range)

         return self.network.add_input(
-            name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
+            name=target,
+            shape=tuple(shape),
+            dtype=unified_dtype_converter(dtype, Frameworks.TRT),
         )

     def call_module(self, target, args, kwargs):
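
Here placeholder() converts in the torch-to-TRT direction: the FX placeholder carries a torch dtype, while INetworkDefinition.add_input expects a tensorrt.DataType. An illustrative call, reusing the sketch above and assuming network is a trt.INetworkDefinition:

# Illustration only: register a float16 network input for a placeholder.
trt_dtype = unified_dtype_converter(torch.float16, Frameworks.TRT)  # trt.float16
inp = network.add_input(name="x", shape=(1, 3, 224, 224), dtype=trt_dtype)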

py/torch_tensorrt/fx/converters/acc_ops_converters.py (+18, −11)
@@ -18,7 +18,7 @@
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target

-from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
+from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks

 from .converter_utils import *  # noqa: F403
 from torch_tensorrt.fx.passes.lower_basic_pass import (
@@ -400,7 +400,7 @@ def acc_ops_pad_with_slice_layer(
     )

     # cast value to TRTensor
-    dt = torch_dtype_from_trt(input_val.dtype)
+    dt = unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
     value = 0 if value == None else value
     value_const = get_trt_tensor(
         network, torch.tensor([value], dtype=dt), f"{name}_value"
@@ -1561,7 +1561,7 @@ def acc_ops_to_dtype(
     input_t = get_trt_tensor(network, input_val, f"{name}_input_t")
     if input_dtype:
         if isinstance(input_dtype, torch.dtype):
-            input_dtype = torch_dtype_to_trt(input_dtype)
+            input_dtype = unified_dtype_converter(input_dtype, Frameworks.TRT)
         input_t = type_cast(network, target, f"{name}_input", input_t, input_dtype)
     return input_t

@@ -1822,7 +1822,7 @@ def acc_ops_logical_xor(
     # f"isinf received input {input_t} that is not part "
     # "of the TensorRT region!"
     # )
-    # tdtype = torch_dtype_from_trt(input_t.dtype)
+    # tdtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)

     # inf_t = torch.ones(tuple(input_t.shape))
     # inf_t = inf_t * float("inf")
@@ -1860,7 +1860,7 @@ def acc_ops_any(

     if input_t.dtype in (trt.float32, trt.float16, trt.int32):
         comp_t = torch.zeros(tuple([*input_t.shape])).to(
-            torch_dtype_from_trt(input_t.dtype)
+            unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
         )
         comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t")
         kwargs_new = {"input": input_t, "other": comp_t}
@@ -2749,7 +2749,7 @@ def acc_ops_masked_fill_tensor(
     if type(value_t) is torch.Tensor:
         value_t = value_t.cpu().numpy()
     # cast to input type
-    input_dtype = torch_dtype_from_trt(input_t.dtype)
+    input_dtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
     value_t = (torch.ones(shape) * value_t).to(input_dtype)
     input_val = get_trt_tensor(network, input_t, f"{name}_input")
     value_val = get_trt_tensor(network, value_t, f"{name}_input")
@@ -2883,7 +2883,11 @@ def add_clamp(network, input, val, op, name):
         # clamping scalar
         acc_ops_clamp_trt = get_trt_tensor(
             network,
-            squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
+            squeeze_left(
+                torch.tensor(
+                    [val], dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH)
+                )
+            ),
             f"{name}_clamp_{val}",
         )
     else:
@@ -2892,7 +2896,8 @@ def add_clamp(network, input, val, op, name):
             (
                 val
                 * torch.ones(
-                    acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)
+                    acc_ops_clamp_shape,
+                    dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
                 )
             )
             .cpu()
@@ -3538,7 +3543,9 @@ def acc_ops_cumsum(
     iterator = loop.add_iterator(input_val, dim, False)
     data = iterator.get_output(0)
     new_dims = tuple(data.shape)
-    zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype))
+    zero_tensor = torch.zeros(
+        new_dims, dtype=unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
+    )
     zero_tensor = network.add_constant(
         zero_tensor.shape, to_numpy(zero_tensor)
     ).get_output(0)
@@ -3689,7 +3696,7 @@ def acc_ops_new_ones(
     dtype_val = kwargs.get("dtype")
     if dtype_val is None:
         dtype_val = input_val.dtype
-        dtype_val = torch_dtype_from_trt(dtype_val)
+        dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)

     device_val = kwargs.get("device")
     assert (
@@ -3713,7 +3720,7 @@ def acc_ops_new_empty(
     dtype_val = kwargs.get("dtype")
     if dtype_val is None:
         dtype_val = input_val.dtype
-        dtype_val = torch_dtype_from_trt(dtype_val)
+        dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)

     device_val = kwargs.get("device")
     assert (
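
The converter hunks above all go the opposite way, TRT to torch, whenever a host-side torch tensor (a pad value, a clamp constant, cumsum's zero tensor) must match the dtype of an existing TRT tensor. Note that the old cumsum line even called a third naming variant, trt_dtype_to_torch_dtype, exactly the inconsistency the unified helper removes. The recurring pattern, with input_val assumed to be a trt.ITensor:

# Build a torch constant whose dtype matches a TRT tensor,
# as acc_ops_cumsum now does.
torch_dtype = unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
zero_tensor = torch.zeros(tuple(input_val.shape), dtype=torch_dtype)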

py/torch_tensorrt/fx/converters/aten_ops_converters.py (−2)
@@ -18,8 +18,6 @@
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target

-from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
-
 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters.impl import activation