
Commit 9cde6af

Move fixes into Dynamo directory

1 parent 73a0bce  |  commit 9cde6af
19 files changed: +307 -129 lines

py/torch_tensorrt/dynamo/backend/backends.py (+13 -3)
@@ -9,7 +9,7 @@
 import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
-from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
+from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
 from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
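This fix tracks an upstream rename: ConstantFolder and replace_node_with_constant moved from torch._inductor.freezing to torch._inductor.constant_folding. The commit simply pins the new path; purely as an illustration, a version-tolerant import could hedge across both locations (this shim is not part of the commit):

# Illustrative compatibility shim, not part of this commit: prefer the new
# module path and fall back to the pre-rename one on older PyTorch builds.
try:
    from torch._inductor.constant_folding import (
        ConstantFolder,
        replace_node_with_constant,
    )
except ImportError:
    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant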
@@ -100,7 +100,7 @@ def _pretraced_backend(
                 + "Returning GraphModule forward instead.",
                 exc_info=True,
             )
-            return gm.forward
+            return gm
         else:
             logger.critical(
                 "Halting compilation on build failure since "
@@ -114,6 +114,13 @@ def _pretraced_backend(

 @torch.utils._python_dispatch._disable_current_modes()  # type: ignore
 def constant_fold(gm: torch.fx.GraphModule) -> Any:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
+
+    Folds constants in the graph module, not skipping constructors
+
+    Modifies the graph in-place and replaces node with constants
+    """
     cf = ConstantFolder(gm, skip_constructors=False)
     cf.run()

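Only the first two lines of constant_fold appear in this hunk. Based on the linked freezing.py reference, the folded values are then written back with replace_node_with_constant and the graph is cleaned up; a hedged sketch of that remainder (attribute and method names follow the upstream reference, not lines shown in this diff):

cf = ConstantFolder(gm, skip_constructors=False)
cf.run()

# Write each folded tensor back into the graph as a constant attribute
# (node_replacements is the mapping populated by ConstantFolder.run()).
for node, constant in cf.node_replacements.items():
    replace_node_with_constant(gm, node, constant)

# Drop nodes that became dead once their outputs were replaced, then recompile.
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()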
@@ -141,10 +148,13 @@ def aot_export_for_compile(
     decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
 ) -> torch.fx.GraphModule:
     """Adapted from:
-    https://github.com/pytorch/pytorch/blob/054f3f1d8f9eb63ef8437991eba5b8f2aeee920f/torch/_functorch/aot_autograd.py#L4133-L4134
+    https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158

     Removed check for input aliasing in resultant subgraph - TRT is functional-only
+
+    Exports the function to ATen for torch compile
     """
+    # Trace function with input arguments and decompositions
     with torch.no_grad():
         fx_g, metadata, in_spec, out_spec = _aot_export_function(
             func,
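For context, a hedged usage sketch of the exported helper; the positional parameters after func are assumed from the surrounding (unshown) signature, and None is the decompositions default visible in the hunk:

import torch

model = torch.nn.Linear(8, 4).eval().cuda()
sample_inputs = [torch.randn(2, 8).cuda()]

# Assumed call shape: (func, inputs, decompositions=...) per the hunk above.
graph_module = aot_export_for_compile(
    model,
    sample_inputs,
    decompositions=None,  # or a Dict[OpOverload, Callable] lowering table
)
print(graph_module.graph)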

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+1 -1)
@@ -361,7 +361,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
         outputs = (args[0],)

         for output_idx in range(len(outputs)):
-            from torch_tensorrt.fx.converters import get_trt_tensor
+            from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor

             output = outputs[output_idx]

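The only change here is the import path; get_trt_tensor now comes from the Dynamo converter utilities instead of the fx converters. A hedged sketch of how such a helper is typically used when marking outputs (the surrounding variables and the exact keyword set are assumptions, not lines from this file):

import tensorrt as trt
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor

# Assumed context: `network` is the TRT INetworkDefinition being built and
# `output` may be a Python/torch constant rather than an ITensor.
if not isinstance(output, trt.ITensor):
    output = get_trt_tensor(network, output, f"output_{output_idx}")
network.mark_output(output)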
py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+42 -48)
@@ -94,7 +94,7 @@ def aten_ops_fmod(
     return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1])


-@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
+@dynamo_tensorrt_converter(torch.ops.aten.relu.default)  # type: ignore[misc]
 def aten_ops_relu(
     network: TRTNetwork,
     target: Target,
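Every remaining hunk in this file makes the same edit: appending # type: ignore[misc] to the registration decorators, presumably because the untyped dynamo_tensorrt_converter decorator would otherwise trip mypy's misc check on these annotated functions. The registration pattern looks like the sketch below, with the relu body reconstructed from the file's usual dispatch style rather than from lines shown here:

@dynamo_tensorrt_converter(torch.ops.aten.relu.default)  # type: ignore[misc]
def aten_ops_relu(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Delegate to the shared activation implementation.
    return impl.activation.relu(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )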
@@ -111,7 +111,7 @@ def aten_ops_relu(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)  # type: ignore[misc]
 def aten_ops_sigmoid(
     network: TRTNetwork,
     target: Target,
@@ -128,7 +128,7 @@ def aten_ops_sigmoid(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)
+@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)  # type: ignore[misc]
 def aten_ops_tanh(
     network: TRTNetwork,
     target: Target,
@@ -145,7 +145,7 @@ def aten_ops_tanh(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)
+@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)  # type: ignore[misc]
 def aten_ops_leaky_relu(
     network: TRTNetwork,
     target: Target,
@@ -163,7 +163,7 @@ def aten_ops_leaky_relu(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.elu.default)
+@dynamo_tensorrt_converter(torch.ops.aten.elu.default)  # type: ignore[misc]
 def aten_ops_elu(
     network: TRTNetwork,
     target: Target,
@@ -182,7 +182,7 @@ def aten_ops_elu(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)
+@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)  # type: ignore[misc]
 def aten_ops_softplus(
     network: TRTNetwork,
     target: Target,
@@ -200,7 +200,7 @@ def aten_ops_softplus(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.default)  # type: ignore[misc]
 def aten_ops_clip(
     network: TRTNetwork,
     target: Target,
@@ -219,7 +219,7 @@ def aten_ops_clip(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
+@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)  # type: ignore[misc]
 def aten_ops_hard_sigmoid(
     network: TRTNetwork,
     target: Target,
@@ -296,26 +296,20 @@ def aten_ops_rsqrt(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
+@dynamo_tensorrt_converter(torch.ops.aten.neg.default)  # type: ignore[misc]
 def aten_ops_neg(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = args[0]
-    if (isinstance(input_val, TRTTensor)) and (
-        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
-    ):
-        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
-
     return impl.unary.neg(
         network,
         target,
         SourceIR.ATEN,
         name,
-        input_val,
+        args[0],
     )

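This hunk drops the converter-level int8/int32 promotion before calling impl.unary.neg, leaving the converter a plain pass-through of args[0]. A hedged sketch of where that guard plausibly lives instead, inside the shared unary implementation (the helper name convert_unary and the exact module layout are assumptions):

import tensorrt as trt

def neg(network, target, source_ir, name, input_val):
    # Promote integer tensors to float32 before the TensorRT unary layer,
    # mirroring the guard removed from the ATen converter above.
    if isinstance(input_val, TRTTensor) and input_val.dtype in (trt.int8, trt.int32):
        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
    return convert_unary(  # assumed shared helper that adds the IUnaryLayer
        network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
    )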
@@ -503,7 +497,7 @@ def aten_ops_clone(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
+@dynamo_tensorrt_converter(torch.ops.aten.expand.default)  # type: ignore[misc]
 def aten_ops_expand(
     network: TRTNetwork,
     target: Target,
@@ -533,7 +527,7 @@ def amax_param_validator(amax_node: Node) -> bool:

 @dynamo_tensorrt_converter(
     torch.ops.aten.amax.default, capability_validator=amax_param_validator
-)
+)  # type: ignore[misc]
 def aten_ops_amax(
     network: TRTNetwork,
     target: Target,
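capability_validator gives a converter a cheap predicate over the FX node: if it returns False, the converter declines the node and that op is left to run in PyTorch rather than being lowered to TensorRT. A minimal illustrative validator (the check shown is an example, not the actual body of amax_param_validator):

from torch.fx.node import Node

def amax_param_validator(amax_node: Node) -> bool:
    # Illustrative check: only claim aten.amax calls that pass an explicit
    # reduction dimension, declining the dimensionless overload.
    return len(amax_node.args) >= 2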
@@ -552,8 +546,8 @@ def aten_ops_amax(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.sum.default)
-@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
+@dynamo_tensorrt_converter(torch.ops.aten.sum.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)  # type: ignore[misc]
 def aten_ops_sum(
     network: TRTNetwork,
     target: Target,
@@ -946,8 +940,8 @@ def aten_ops_isinf(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)  # type: ignore[misc]
 def aten_ops_add(
     network: TRTNetwork,
     target: Target,
@@ -978,8 +972,8 @@ def aten_ops_add(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)  # type: ignore[misc]
 def aten_ops_mul(
     network: TRTNetwork,
     target: Target,
@@ -997,7 +991,7 @@ def aten_ops_mul(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)  # type: ignore[misc]
 def aten_ops_max(
     network: TRTNetwork,
     target: Target,
@@ -1015,7 +1009,7 @@ def aten_ops_max(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)  # type: ignore[misc]
 def aten_ops_min(
     network: TRTNetwork,
     target: Target,
@@ -1033,8 +1027,8 @@ def aten_ops_min(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)  # type: ignore[misc]
 def aten_ops_sub(
     network: TRTNetwork,
     target: Target,
@@ -1065,10 +1059,10 @@ def aten_ops_sub(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
-@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
-@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)  # type: ignore[misc]
 def aten_ops_div(
     network: TRTNetwork,
     target: Target,
@@ -1111,9 +1105,9 @@ def aten_ops_div(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
-@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)  # type: ignore[misc]
 def aten_ops_pow(
     network: TRTNetwork,
     target: Target,
@@ -1131,8 +1125,8 @@ def aten_ops_pow(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
-@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)  # type: ignore[misc]
 def aten_ops_floor_div(
     network: TRTNetwork,
     target: Target,
@@ -1150,7 +1144,7 @@ def aten_ops_floor_div(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
+@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)  # type: ignore[misc]
 def aten_ops_logical_and(
     network: TRTNetwork,
     target: Target,
@@ -1168,7 +1162,7 @@ def aten_ops_logical_and(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
+@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)  # type: ignore[misc]
 def aten_ops_logical_or(
     network: TRTNetwork,
     target: Target,
@@ -1186,7 +1180,7 @@ def aten_ops_logical_or(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
+@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)  # type: ignore[misc]
 def aten_ops_logical_xor(
     network: TRTNetwork,
     target: Target,
@@ -1204,8 +1198,8 @@ def aten_ops_logical_xor(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)  # type: ignore[misc]
 def aten_ops_equal(
     network: TRTNetwork,
     target: Target,
@@ -1223,8 +1217,8 @@ def aten_ops_equal(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)  # type: ignore[misc]
 def aten_ops_greater(
     network: TRTNetwork,
     target: Target,
@@ -1242,8 +1236,8 @@ def aten_ops_greater(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)  # type: ignore[misc]
 def aten_ops_less(
     network: TRTNetwork,
     target: Target,
@@ -1267,7 +1261,7 @@ def conv_param_validator(conv_node: Node) -> bool:

 @dynamo_tensorrt_converter(
     torch.ops.aten.convolution.default, capability_validator=conv_param_validator
-)
+)  # type: ignore[misc]
 def aten_ops_convolution(
     network: TRTNetwork,
     target: Target,
@@ -1291,7 +1285,7 @@ def aten_ops_convolution(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.linear.default)
+@dynamo_tensorrt_converter(torch.ops.aten.linear.default)  # type: ignore[misc]
 def aten_ops_linear(
     network: TRTNetwork,
     target: Target,
