Skip to content

Commit 9cb57c2

Browse files
authored
fix: Repair graph naming for FX legacy suite (#2111)
1 parent 324766f commit 9cb57c2

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
from copy import deepcopy
5+
from packaging import version
56

67
import torch
78
import torch.fx as fx
@@ -42,6 +43,9 @@ def forward(self, x, y):
4243
%reshape : [num_users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
4344
return reshape
4445
"""
46+
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
47+
expected_graph = expected_graph.replace("num_users", "#users")
48+
4549
assert (
4650
str(mod_fixed.graph).strip() == expected_graph.strip()
4751
), f"Unexpected fixed graph. \nActual: {str(mod_fixed.graph)} \nExpected: {expected_graph}"

py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Owner(s): ["oncall: gpu_enablement"]
22

33
import logging
4+
import torch
5+
from packaging import version
46

57
import torch.fx as fx
68
import torch.nn as nn
@@ -54,6 +56,10 @@ def is_leaf_module(self, m, qn):
5456
%add : [num_users=1] = call_function[target=operator.add](args = (%getitem, %getitem_1), kwargs = {})
5557
return add
5658
""".strip()
59+
60+
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
61+
ttop_graph_expected = ttop_graph_expected.replace("num_users", "#users")
62+
5763
assert (
5864
ttop_graph_expected == ttop_graph_actual
5965
), f"Unexpected ttop graph: {ttop_graph_actual}"
@@ -64,6 +70,10 @@ def is_leaf_module(self, m, qn):
6470
%x : [num_users=1] = placeholder[target=x]
6571
return (x,)
6672
""".strip()
73+
74+
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
75+
ttop_a_graph_expected = ttop_a_graph_expected.replace("num_users", "#users")
76+
6777
assert (
6878
ttop_a_graph_expected == ttop_a_graph_actual
6979
), f"Unexpected ttop.a graph: {ttop_a_graph_actual}"

0 commit comments

Comments
 (0)