Skip to content

Commit 6dcd1fc

Browse files
authored
fix: Add support for fake tensors (#1955)
1 parent c2a2f61 commit 6dcd1fc

File tree

9 files changed

+302
-127
lines changed

9 files changed

+302
-127
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

-5
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515
from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs
1616
from torch_tensorrt.dynamo.backend.conversion import convert_module
1717

18-
from torch._dynamo.backends.common import fake_tensor_unsupported
19-
2018
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
2119

2220

2321
logger = logging.getLogger(__name__)
2422

2523

2624
@td.register_backend(name="torch_tensorrt")
27-
@fake_tensor_unsupported
2825
def torch_tensorrt_backend(
2926
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
3027
):
@@ -34,7 +31,6 @@ def torch_tensorrt_backend(
3431

3532

3633
@td.register_backend(name="aot_torch_tensorrt_aten")
37-
@fake_tensor_unsupported
3834
def aot_torch_tensorrt_aten_backend(
3935
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
4036
):
@@ -54,7 +50,6 @@ def aot_torch_tensorrt_aten_backend(
5450
)
5551

5652

57-
@fake_tensor_unsupported
5853
def _pretraced_backend(
5954
gm: torch.fx.GraphModule,
6055
sample_inputs: Sequence[torch.Tensor],
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from utils import lower_graph_testing
2+
from torch.testing._internal.common_utils import run_tests, TestCase
3+
import torch
4+
from torch_tensorrt.dynamo import compile
5+
6+
7+
class TestFakeTensors(TestCase):
8+
def test_lowering_mul_int(self):
9+
class MulInt(torch.nn.Module):
10+
def forward(self, x):
11+
return x * 7
12+
13+
# Operations expected to be included in the traced graph after decompositions
14+
expected_ops = {
15+
torch.ops.aten.mul.Tensor,
16+
}
17+
18+
inputs = [
19+
torch.rand(
20+
3,
21+
5,
22+
7,
23+
).cuda(),
24+
]
25+
26+
fx_graph = torch.fx.symbolic_trace(MulInt())
27+
_, expected_ops_unseen = lower_graph_testing(
28+
fx_graph,
29+
inputs,
30+
expected_ops=expected_ops,
31+
min_block_size=1,
32+
)
33+
34+
self.assertEquals(
35+
len(expected_ops_unseen),
36+
0,
37+
f"The following expected ops were not encountered: {expected_ops_unseen}",
38+
)
39+
40+
torch._dynamo.reset()
41+
42+
# Validate that the results between Torch and Torch-TRT are similar
43+
optimized_model = compile(
44+
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
45+
)
46+
optimized_model_results = optimized_model(*inputs).detach().cpu()
47+
torch_model_results = fx_graph(*inputs).detach().cpu()
48+
49+
max_diff = float(
50+
torch.max(torch.abs(optimized_model_results - torch_model_results))
51+
)
52+
self.assertAlmostEqual(
53+
max_diff,
54+
0,
55+
msg=f"MulInt TRT outputs don't match with the original model.",
56+
)
57+
torch._dynamo.reset()
58+
59+
def test_lowering_add_float(self):
60+
class AddFloat(torch.nn.Module):
61+
def forward(self, x):
62+
return x + 84.0
63+
64+
# Operations expected to be included in the traced graph after decompositions
65+
expected_ops = {
66+
torch.ops.aten.add.Tensor,
67+
}
68+
69+
inputs = [
70+
torch.rand(
71+
1,
72+
5,
73+
7,
74+
9,
75+
).cuda(),
76+
]
77+
78+
fx_graph = torch.fx.symbolic_trace(AddFloat())
79+
_, expected_ops_unseen = lower_graph_testing(
80+
fx_graph,
81+
inputs,
82+
expected_ops=expected_ops,
83+
min_block_size=1,
84+
)
85+
86+
self.assertEquals(
87+
len(expected_ops_unseen),
88+
0,
89+
f"The following expected ops were not encountered: {expected_ops_unseen}",
90+
)
91+
92+
torch._dynamo.reset()
93+
94+
# Validate that the results between Torch and Torch-TRT are similar
95+
optimized_model = compile(
96+
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
97+
)
98+
optimized_model_results = optimized_model(*inputs).detach().cpu()
99+
torch_model_results = fx_graph(*inputs).detach().cpu()
100+
101+
max_diff = float(
102+
torch.max(torch.abs(optimized_model_results - torch_model_results))
103+
)
104+
self.assertAlmostEqual(
105+
max_diff,
106+
0,
107+
msg=f"AddFloat TRT outputs don't match with the original model.",
108+
)
109+
110+
torch._dynamo.reset()
111+
112+
113+
if __name__ == "__main__":
114+
run_tests()

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS
1818
from .input_tensor_spec import InputTensorSpec
1919
from torch_tensorrt.fx.observer import Observer
20-
from torch_tensorrt.fx.utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
20+
from torch_tensorrt.fx.utils import (
21+
get_dynamic_dims,
22+
LowerPrecision,
23+
unified_dtype_converter,
24+
Frameworks,
25+
)
2126

2227
_LOGGER: logging.Logger = logging.getLogger(__name__)
2328

@@ -330,7 +335,9 @@ def placeholder(self, target, args, kwargs):
330335
self.optimization_profiles[i].set_shape(target, *shape_range)
331336

332337
return self.network.add_input(
333-
name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
338+
name=target,
339+
shape=tuple(shape),
340+
dtype=unified_dtype_converter(dtype, Frameworks.TRT),
334341
)
335342

336343
def call_module(self, target, args, kwargs):

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torch.fx.immutable_collections import immutable_list
1919
from torch.fx.node import Argument, Target
2020

21-
from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
21+
from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks
2222

2323
from .converter_utils import * # noqa: F403
2424
from torch_tensorrt.fx.passes.lower_basic_pass import (
@@ -400,7 +400,7 @@ def acc_ops_pad_with_slice_layer(
400400
)
401401

402402
# cast value to TRTensor
403-
dt = torch_dtype_from_trt(input_val.dtype)
403+
dt = unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
404404
value = 0 if value == None else value
405405
value_const = get_trt_tensor(
406406
network, torch.tensor([value], dtype=dt), f"{name}_value"
@@ -1551,7 +1551,7 @@ def acc_ops_to_dtype(
15511551
input_t = get_trt_tensor(network, input_val, f"{name}_input_t")
15521552
if input_dtype:
15531553
if isinstance(input_dtype, torch.dtype):
1554-
input_dtype = torch_dtype_to_trt(input_dtype)
1554+
input_dtype = unified_dtype_converter(input_dtype, Frameworks.TRT)
15551555
input_t = type_cast(network, target, f"{name}_input", input_t, input_dtype)
15561556
return input_t
15571557

@@ -1812,7 +1812,7 @@ def acc_ops_logical_xor(
18121812
# f"isinf received input {input_t} that is not part "
18131813
# "of the TensorRT region!"
18141814
# )
1815-
# tdtype = torch_dtype_from_trt(input_t.dtype)
1815+
# tdtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
18161816

18171817
# inf_t = torch.ones(tuple(input_t.shape))
18181818
# inf_t = inf_t * float("inf")
@@ -1850,7 +1850,7 @@ def acc_ops_any(
18501850

18511851
if input_t.dtype in (trt.float32, trt.float16, trt.int32):
18521852
comp_t = torch.zeros(tuple([*input_t.shape])).to(
1853-
torch_dtype_from_trt(input_t.dtype)
1853+
unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
18541854
)
18551855
comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t")
18561856
kwargs_new = {"input": input_t, "other": comp_t}
@@ -2739,7 +2739,7 @@ def acc_ops_masked_fill_tensor(
27392739
if type(value_t) is torch.Tensor:
27402740
value_t = value_t.cpu().numpy()
27412741
# cast to input type
2742-
input_dtype = torch_dtype_from_trt(input_t.dtype)
2742+
input_dtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
27432743
value_t = (torch.ones(shape) * value_t).to(input_dtype)
27442744
input_val = get_trt_tensor(network, input_t, f"{name}_input")
27452745
value_val = get_trt_tensor(network, value_t, f"{name}_input")
@@ -2873,7 +2873,11 @@ def add_clamp(network, input, val, op, name):
28732873
# clamping scalar
28742874
acc_ops_clamp_trt = get_trt_tensor(
28752875
network,
2876-
squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
2876+
squeeze_left(
2877+
torch.tensor(
2878+
[val], dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH)
2879+
)
2880+
),
28772881
f"{name}_clamp_{val}",
28782882
)
28792883
else:
@@ -2882,7 +2886,8 @@ def add_clamp(network, input, val, op, name):
28822886
(
28832887
val
28842888
* torch.ones(
2885-
acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)
2889+
acc_ops_clamp_shape,
2890+
dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
28862891
)
28872892
)
28882893
.cpu()
@@ -3528,7 +3533,9 @@ def acc_ops_cumsum(
35283533
iterator = loop.add_iterator(input_val, dim, False)
35293534
data = iterator.get_output(0)
35303535
new_dims = tuple(data.shape)
3531-
zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype))
3536+
zero_tensor = torch.zeros(
3537+
new_dims, dtype=unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
3538+
)
35323539
zero_tensor = network.add_constant(
35333540
zero_tensor.shape, to_numpy(zero_tensor)
35343541
).get_output(0)
@@ -3671,7 +3678,7 @@ def acc_ops_new_ones(
36713678
dtype_val = kwargs.get("dtype")
36723679
if dtype_val is None:
36733680
dtype_val = input_val.dtype
3674-
dtype_val = torch_dtype_from_trt(dtype_val)
3681+
dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)
36753682

36763683
device_val = kwargs.get("device")
36773684
assert (
@@ -3695,7 +3702,7 @@ def acc_ops_new_empty(
36953702
dtype_val = kwargs.get("dtype")
36963703
if dtype_val is None:
36973704
dtype_val = input_val.dtype
3698-
dtype_val = torch_dtype_from_trt(dtype_val)
3705+
dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)
36993706

37003707
device_val = kwargs.get("device")
37013708
assert (

py/torch_tensorrt/fx/converters/aten_ops_converters.py

-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
from torch.fx.immutable_collections import immutable_list
1919
from torch.fx.node import Argument, Target
2020

21-
from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
22-
2321
from .converter_utils import * # noqa: F403
2422
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
2523
from torch_tensorrt.fx.converters.impl import activation

0 commit comments

Comments
 (0)