Skip to content

Commit dfc4899

Browse files
committed
Move fixes into Dynamo directory
1 parent 73a0bce commit dfc4899

File tree

18 files changed

+265
-81
lines changed

18 files changed

+265
-81
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch.utils._pytree as pytree
1010
from torch._dynamo.utils import detect_fake_mode
1111
from torch._functorch.aot_autograd import _aot_export_function
12-
from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
12+
from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
1313
from torch._ops import OpOverload
1414
from torch_tensorrt.dynamo import CompilationSettings
1515
from torch_tensorrt.dynamo.compile import compile_module
@@ -100,7 +100,7 @@ def _pretraced_backend(
100100
+ "Returning GraphModule forward instead.",
101101
exc_info=True,
102102
)
103-
return gm.forward
103+
return gm
104104
else:
105105
logger.critical(
106106
"Halting compilation on build failure since "
@@ -114,6 +114,13 @@ def _pretraced_backend(
114114

115115
@torch.utils._python_dispatch._disable_current_modes() # type: ignore
116116
def constant_fold(gm: torch.fx.GraphModule) -> Any:
117+
"""Adapted from:
118+
https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
119+
120+
Folds constants in the graph module, not skipping constructors
121+
122+
Modifies the graph in-place and replaces nodes with constants
123+
"""
117124
cf = ConstantFolder(gm, skip_constructors=False)
118125
cf.run()
119126

@@ -141,10 +148,13 @@ def aot_export_for_compile(
141148
decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
142149
) -> torch.fx.GraphModule:
143150
"""Adapted from:
144-
https://github.com/pytorch/pytorch/blob/054f3f1d8f9eb63ef8437991eba5b8f2aeee920f/torch/_functorch/aot_autograd.py#L4133-L4134
151+
https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
145152
146153
Removed check for input aliasing in resultant subgraph - TRT is functional-only
154+
155+
Exports the function to ATen for torch compile
147156
"""
157+
# Trace function with input arguments and decompositions
148158
with torch.no_grad():
149159
fx_g, metadata, in_spec, out_spec = _aot_export_function(
150160
func,

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
361361
outputs = (args[0],)
362362

363363
for output_idx in range(len(outputs)):
364-
from torch_tensorrt.fx.converters import get_trt_tensor
364+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
365365

366366
output = outputs[output_idx]
367367

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+79-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import functools
22
import logging
33
import re
4-
from typing import Any, List, Optional, Tuple
4+
from typing import Any, List, Optional, Tuple, Union
55

6+
import numpy as np
67
import tensorrt as trt
78
import torch
89
from torch.fx.node import Target
910
from torch_tensorrt.fx.converters.converter_utils import (
1011
Frameworks,
1112
get_axes_for_reduce_op,
13+
to_numpy,
1214
unified_dtype_converter,
1315
)
1416
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
@@ -185,11 +187,85 @@ def extend_attr_to_tuple(
185187

186188
if isinstance(val, list):
187189
val = tuple(val)
188-
return val
190+
191+
if isinstance(val, tuple):
192+
return val
193+
else:
194+
raise AssertionError(f"Could not extend attribute {val}")
189195

190196

191-
def cast_int_or_float_to_bool(network: TRTNetwork, name: str, tensor: TRTTensor):
197+
def cast_int_or_float_to_bool(
    network: TRTNetwork, name: str, tensor: TRTTensor
) -> TRTTensor:
    """Return `tensor` as a TensorRT bool tensor.

    Args:
        network (TRTNetwork): Network passed through to `cast_trt_tensor`
            when a cast is required.
        name (str): Name passed through to `cast_trt_tensor` for the cast.
        tensor (TRTTensor): Tensor whose dtype is checked against `trt.bool`.
    Returns:
        `tensor` itself when its dtype is already `trt.bool`; otherwise the
        result of `cast_trt_tensor(network, tensor, trt.bool, name)`.
    """
    # Already boolean: hand the input back untouched, no cast is requested
    if tensor.dtype == trt.bool:
        return tensor

    return cast_trt_tensor(network, tensor, trt.bool, name)
204+
205+
206+
def create_constant(
    network: TRTNetwork,
    value: Union[int, float, np.ndarray, torch.Tensor],
    name: str,
    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]],
) -> TRTTensor:
    """
    Add a constant layer holding `value` to `network` and return its output.
    Args:
        network (TRTNetwork): The TensorRT network that receives the new
            constant layer.
        value (Union[int, float, np.ndarray, torch.Tensor]): A Python scalar,
            Numpy array, or PyTorch tensor providing the constant's contents.
        name (str): Name assigned to the new constant layer.
        dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
            Forwarded to `to_numpy` together with `value`; when given, the
            value is expected to be converted to this dtype.
    Returns:
        The first output (index 0) of the newly added constant layer.
    """
    # Python scalars carry no shape attribute, so they are registered as (1,)
    if isinstance(value, (int, float)):
        shape = (1,)
    else:
        shape = value.shape

    # The layer is given its own copy of the converted data
    weights = to_numpy(value, dtype).copy()

    layer = network.add_constant(shape, weights)
    layer.name = name
    return layer.get_output(0)
231+
232+
233+
def get_trt_tensor(
    network: TRTNetwork,
    input_val: Any,
    name: str,
    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
) -> TRTTensor:
    """
    Convert a value of arbitrary type to a TensorRT ITensor where possible.
    Args:
        network (TRTNetwork): A TensorRT network. If a constant layer has to
            be created for `input_val`, it is added to this network.
        input_val (Any): The value to convert to a TensorRT ITensor.
        name (str): Name of the constant layer, if one is created.
        dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
            If provided, it is forwarded to `create_constant` so the value
            can be converted to this dtype.
    Returns:
        A TensorRT ITensor representing the given value. An existing
        TRTTensor is returned unchanged.
    Raises:
        AssertionError: If `input_val` is not a bool, int, float, Numpy
            array, PyTorch tensor or TRTTensor.
    """
    # Bool constants are not added directly: Python bools become ints, and
    # bool tensors/arrays become int32 (per the original note, a cast back to
    # bool is done later by callers that need it, e.g. logical operations).
    # int64 tensors/arrays are likewise narrowed to int32.
    if isinstance(input_val, bool):
        input_val = int(input_val)
    elif isinstance(input_val, torch.Tensor):
        if input_val.dtype == torch.bool or input_val.dtype == torch.int64:
            input_val = input_val.to(torch.int32)
    elif isinstance(input_val, np.ndarray):
        if input_val.dtype == np.bool_ or input_val.dtype == np.int64:
            input_val = input_val.astype(np.int32)

    # Scalars, arrays and tensors are materialized as a constant layer
    if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)):
        return create_constant(network, input_val, name, dtype)

    # Values that are already TensorRT tensors pass straight through
    if isinstance(input_val, TRTTensor):
        return input_val

    raise AssertionError(f"Cannot convert {input_val} to TRT constant")

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
from typing import Optional
22

3+
import tensorrt as trt
34
import torch
45
from torch.fx.node import Target
56
from torch_tensorrt.dynamo._SourceIR import SourceIR
6-
from torch_tensorrt.dynamo.conversion.converter_utils import broadcastable
7-
from torch_tensorrt.dynamo.conversion.impl.slice import expand
8-
from torch_tensorrt.fx.converters.converter_utils import (
9-
broadcast,
7+
from torch_tensorrt.dynamo.conversion.converter_utils import (
8+
broadcastable,
109
get_trt_tensor,
11-
set_layer_name,
1210
)
11+
from torch_tensorrt.dynamo.conversion.impl.slice import expand
12+
from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
1313
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1414

15-
import tensorrt as trt
16-
1715

1816
def where(
1917
network: TRTNetwork,

py/torch_tensorrt/dynamo/conversion/impl/conv.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
import torch
88
from torch.fx.node import Target
99
from torch_tensorrt.dynamo.conversion import impl
10-
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
10+
from torch_tensorrt.dynamo.conversion.converter_utils import (
11+
extend_attr_to_tuple,
12+
get_trt_tensor,
13+
)
1114
from torch_tensorrt.fx.converters.converter_utils import (
1215
SourceIR,
1316
get_dyn_range,
14-
get_trt_tensor,
1517
has_dynamic_shape,
1618
mark_as_int8_layer,
1719
set_layer_name,
@@ -27,8 +29,8 @@ def convNd(
2729
name: str,
2830
is_conv1d: bool,
2931
input: TRTTensor,
30-
weight: Union[TRTTensor, torch.Tensor],
31-
bias: Optional[Union[TRTTensor, torch.Tensor]],
32+
weight: Union[TRTTensor, torch.Tensor, np.ndarray],
33+
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
3234
stride: Optional[Union[int, Sequence[int]]],
3335
padding: Optional[Union[int, Sequence[int]]],
3436
dilation: Optional[Union[int, Sequence[int]]],
@@ -97,19 +99,28 @@ def convNd(
9799
if isinstance(bias, TRTTensor):
98100
conv_layer.set_input(2, bias)
99101

102+
# Cast certain fields to tuples, in accordance with TRT requirements
103+
padding = (padding,) if isinstance(padding, int) else padding
104+
stride = (stride,) if isinstance(stride, int) else stride
105+
dilation = (dilation,) if isinstance(dilation, int) else dilation
106+
100107
# Expand parameters manually for Conv1D computations
101108
if is_conv1d:
102-
padding = tuple(padding) + (0,)
103-
stride = extend_attr_to_tuple(stride, 2)
104-
dilation = extend_attr_to_tuple(dilation, 2)
109+
padding = (tuple(padding) + (0,)) if padding is not None else padding
110+
stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride
111+
dilation = (
112+
extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
113+
)
105114

106115
set_layer_name(conv_layer, target, name, source_ir)
107116

108117
# Set relevant attributes of convolution layer
109-
conv_layer.padding_nd = padding
110-
conv_layer.stride_nd = stride
111-
conv_layer.dilation_nd = dilation
112-
118+
if padding is not None:
119+
conv_layer.padding_nd = padding
120+
if stride is not None:
121+
conv_layer.stride_nd = stride
122+
if dilation is not None:
123+
conv_layer.dilation_nd = dilation
113124
if groups is not None:
114125
conv_layer.num_groups = groups
115126

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
import torch
88
from torch.fx.node import Target
99
from torch_tensorrt.dynamo._SourceIR import SourceIR
10-
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
10+
from torch_tensorrt.dynamo.conversion.converter_utils import (
11+
cast_trt_tensor,
12+
get_trt_tensor,
13+
)
1114
from torch_tensorrt.fx.converters.converter_utils import (
1215
broadcast,
13-
get_trt_tensor,
1416
set_layer_name,
1517
squeeze_left,
1618
)

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,14 @@
77
from torch_tensorrt.dynamo.conversion.converter_utils import (
88
cast_int_int_div_trt_tensor,
99
cast_int_or_float_to_bool,
10+
get_trt_tensor,
1011
)
1112
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1213
convert_binary_elementwise,
1314
)
1415
from torch_tensorrt.dynamo.conversion.impl.unary import sign
1516
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
16-
from torch_tensorrt.fx.converters.converter_utils import (
17-
get_trt_tensor,
18-
set_layer_name,
19-
squeeze_left,
20-
)
17+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left
2118
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
2219
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
2320

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import torch
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
6-
from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor, set_layer_name
6+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
7+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
78
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
89

910

py/torch_tensorrt/dynamo/conversion/impl/linear.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from torch.fx.node import Target
77
from torch_tensorrt.dynamo.conversion import impl
8-
from torch_tensorrt.fx.converters.converter_utils import SourceIR, get_trt_tensor
8+
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
99
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1010

1111

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
import tensorrt as trt
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
6-
from torch_tensorrt.fx.converters.converter_utils import (
7-
broadcast,
8-
get_trt_tensor,
9-
set_layer_name,
10-
)
6+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
7+
from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
118
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
129
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
1310

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from torch.fx.node import Target
44
from torch_tensorrt.dynamo._SourceIR import SourceIR
5+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
56
from torch_tensorrt.fx.converters.converter_utils import (
67
get_positive_dim,
7-
get_trt_tensor,
88
set_layer_name,
99
)
1010
from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor

py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py

-2
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
124124

125125
# Replace all original node uses and clean up graph
126126
n.replace_all_uses_with(new_node)
127-
gm.graph.eliminate_dead_code()
128127
gm.graph.lint()
129128
gm.recompile()
130129

@@ -138,7 +137,6 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
138137
continue
139138

140139
# Perform cleanup and recompilation before returning module
141-
gm.graph.eliminate_dead_code()
142140
gm.graph.lint()
143141
gm.recompile()
144142

0 commit comments

Comments
 (0)