
Commit 91fcea4

feat: Add Selective ATen decompositions (#2173)

1 parent f70574e
6 files changed: +314 -24 lines

py/torch_tensorrt/dynamo/_defaults.py (+1)

@@ -11,3 +11,4 @@
 TRUNCATE_LONG_AND_DOUBLE = False
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
+ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False

py/torch_tensorrt/dynamo/_settings.py (+23)

@@ -4,6 +4,7 @@
 import torch
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
@@ -19,6 +20,27 @@

 @dataclass
 class CompilationSettings:
+    """Compilation settings for Torch-TensorRT Dynamo Paths
+
+    Args:
+        precision (torch.dtype): Model Layer precision
+        debug (bool): Whether to print out verbose debugging information
+        workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size (int): Minimum number of operators per TRT-Engine Block
+        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible (bool): Provide version forward-compatibility for engine plan files
+        optimization_level (Optional[int]): Builder optimization 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
+        truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32
+        enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
+            or only a selected subset of them
+    """
+
     precision: torch.dtype = PRECISION
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
@@ -31,3 +53,4 @@ class CompilationSettings:
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
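For orientation, a minimal usage sketch of the extended settings object; the field values here are illustrative only, with all other fields falling back to their defaults:

import torch
from torch_tensorrt.dynamo import CompilationSettings

# Illustrative values only: opt into the full core ATen decomposition
# set instead of the curated subset (the default is False).
settings = CompilationSettings(
    precision=torch.float16,
    debug=True,
    enable_experimental_decompositions=True,
)
print(settings.enable_experimental_decompositions)  # True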

py/torch_tensorrt/dynamo/backend/backends.py (+1, -1)

@@ -56,7 +56,7 @@ def aot_torch_tensorrt_aten_backend(
         gm,
         sample_inputs,
         fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(),
+        decompositions=get_decompositions(settings.enable_experimental_decompositions),
     )
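The call site above now forwards the new flag into get_decompositions, whose body is not part of this excerpt. A plausible sketch, assuming it simply switches between the filtered full core ATen table and the curated subset defined in the new decompositions module below:

from typing import Any, Callable, Dict

from torch._ops import OpOverload

# Sketch only. _core_aten_decompositions, torch_disabled_decompositions,
# ENABLED_TORCH_DECOMPOSITIONS, and TORCH_TRT_DECOMPOSITIONS are the names
# defined by the new decompositions module in this commit.
def get_decompositions(
    enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
    if enable_experimental_decompositions:
        # All core ATen decompositions, minus any explicitly disabled ops
        filtered = {
            op: decomp
            for op, decomp in _core_aten_decompositions.items()
            if op not in torch_disabled_decompositions
        }
        return {**filtered, **TORCH_TRT_DECOMPOSITIONS}
    # Default: only the curated subset, plus any Torch-TRT-specific entries
    return {**ENABLED_TORCH_DECOMPOSITIONS, **TORCH_TRT_DECOMPOSITIONS}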

py/torch_tensorrt/dynamo/compile.py (+7, -3)

@@ -13,6 +13,7 @@
 from torch_tensorrt.dynamo import CompilationSettings, partitioning
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
@@ -61,6 +62,7 @@ def compile(
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
     use_fast_partitioner: bool = USE_FAST_PARTITIONER,
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     if debug:
@@ -71,9 +73,10 @@ def compile(

     logger.warning(
         "The Dynamo backend is an experimental feature, for which only the "
-        + "following arguments are supported: "
-        + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}"
+        "following arguments are supported: "
+        "{enabled_precisions, debug, workspace_size, min_block_size, "
+        "torch_executed_ops, pass_through_build_failures, use_fast_partitioner, "
+        "enable_experimental_decompositions}"
     )

     if not isinstance(inputs, collections.abc.Sequence):
@@ -114,6 +117,7 @@ def compile(
         "use_python_runtime": use_python_runtime,
         "truncate_long_and_double": truncate_long_and_double,
         "use_fast_partitioner": use_fast_partitioner,
+        "enable_experimental_decompositions": enable_experimental_decompositions,
     }

     settings = CompilationSettings(**compilation_options)
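End to end, the flag flows from compile() into CompilationSettings and on to the backend. A minimal caller-side sketch; the toy model, shapes, and the torch_tensorrt.dynamo.compile entry-point spelling are assumptions for illustration, not part of this diff:

import torch
import torch_tensorrt

# Placeholder model and inputs for illustration.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU()).cuda().eval()
inputs = [torch.randn(4, 8).cuda()]

# Opt into the full core ATen decomposition set; the default
# (ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False) keeps the curated subset.
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs,
    enable_experimental_decompositions=True,
)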
py/torch_tensorrt/dynamo/lowering/_decompositions.py (new file, +200)
@@ -0,0 +1,200 @@
from typing import Any, Callable, Dict, Set

import torch
from torch._decomp import core_aten_decompositions
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload

aten = torch.ops.aten

_core_aten_decompositions: Dict[
    OpOverload, Callable[[Any], Any]
] = core_aten_decompositions()
torch_enabled_decompositions: Set[OpOverload] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
    aten.addcmul_,
    aten.addr,
    aten.aminmax,
    aten.arange.default,
    aten.arange.start,
    aten.avg_pool2d_backward,
    aten.binary_cross_entropy,
    aten.binary_cross_entropy_backward,
    aten.binary_cross_entropy_with_logits,
    aten.celu,
    aten.col2im,
    aten.count_nonzero,
    aten.cudnn_batch_norm,
    aten.cudnn_batch_norm_backward,
    aten.deg2rad,
    aten.detach,
    aten.diag_embed,
    aten.diagonal_backward,
    aten.dot,
    aten.elu,
    aten.elu_backward,
    aten._embedding_bag,
    aten.embedding_dense_backward,
    aten._euclidean_dist.default,
    aten.expand_as,
    aten.eye,
    aten.fill,
    aten.frac,
    aten._fused_moving_avg_obs_fq_helper,
    aten.gelu,
    aten.gelu_backward,
    aten.glu_backward,
    aten.grid_sampler_2d,
    aten.hardshrink,
    aten.hardshrink_backward,
    aten.hardsigmoid,
    aten.hardsigmoid_backward,
    aten.hardswish,
    aten.hardswish_,
    aten.hardswish_backward,
    aten.hardtanh,
    aten.hardtanh_,
    aten.hardtanh_backward,
    aten.heaviside,
    aten.huber_loss,
    aten.huber_loss_backward,
    aten.im2col,
    aten.index_add,
    aten.index_add_,
    aten.index_copy,
    aten.index_copy_,
    aten.index_fill,
    aten.index_fill_,
    aten.index_select,
    aten.isneginf,
    aten.isposinf,
    aten.l1_loss,
    aten.leaky_relu,
    aten.leaky_relu_,
    aten.leaky_relu_backward,
    aten.lerp,
    aten.linspace,
    aten.logaddexp,
    aten.logaddexp2,
    aten.logit,
    aten.logit_backward,
    aten.log_sigmoid_backward,
    aten.log_sigmoid_forward,
    aten._log_softmax,
    aten._log_softmax_backward_data,
    aten.logspace,
    aten.logsumexp.default,
    aten.masked_fill,
    aten.masked_fill_,
    aten.max_pool2d_with_indices_backward,
    aten.mish,
    aten.mse_loss,
    aten.mse_loss_backward,
    aten.mv,
    aten.mvlgamma,
    aten.nansum,
    aten.nan_to_num,
    aten.narrow,
    # TODO: Disable the below operators once freezing is done
    aten.native_batch_norm,
    aten.native_batch_norm_backward,
    aten._native_batch_norm_legit,
    aten._native_batch_norm_legit_functional,
    aten._native_batch_norm_legit_no_training,
    aten.native_dropout_backward,
    aten.native_group_norm,
    aten.native_group_norm_backward,
    aten.native_layer_norm,
    aten.native_layer_norm_backward,
    aten.new_empty,
    aten.new_full,
    aten.new_ones,
    aten.new_zeros,
    aten.nll_loss_backward,
    aten.nll_loss_forward,
    aten.norm,
    aten.ones,
    aten.ones_like,
    aten._prelu_kernel,
    aten._prelu_kernel_backward,
    aten._reshape_alias,
    aten.rad2deg,
    aten.renorm,
    aten.renorm_,
    aten.rot90,
    aten.rsub.Scalar,
    aten.rsub.Tensor,
    aten.select_backward,
    aten.select_scatter,
    aten.sgn,
    aten.sigmoid_backward,
    aten.silu,
    aten.silu_,
    aten.silu_backward,
    aten.sinc,
    aten.slice_backward,
    aten.smooth_l1_loss,
    aten.smooth_l1_loss_backward,
    aten.soft_margin_loss,
    aten.soft_margin_loss_backward,
    aten._softmax,
    aten._softmax_backward_data,
    aten.softplus,
    aten.softplus_backward,
    aten.softshrink,
    aten.softshrink_backward,
    aten.special_entr,
    aten.special_log_ndtr,
    aten.special_xlog1py,
    aten.stack,
    aten.t,
    aten.tanh_backward,
    aten.threshold,
    aten.threshold_backward,
    aten.trace,
    aten.transpose.int,
    aten.tril.default,
    aten.triu.default,
    aten.unfold,
    aten.unfold_backward,
    aten.unfold_copy,
    aten.upsample_bilinear2d,
    aten.upsample_bilinear2d.vec,
    aten.upsample_nearest2d_backward,
    aten.xlogy,
    aten.zero,
    aten.zero_,
    aten.zeros,
    aten.zeros_like,
    # Non-default convenience decompositions
    aten.clamp_min,
    aten.clamp_max,
    aten.linalg_vector_norm,
    aten.full,
    aten.repeat,
}
torch_disabled_decompositions: Set[OpOverload] = set()


ENABLED_TORCH_DECOMPOSITIONS: Dict[
    OpOverload, Callable[[Any], Any]
] = get_torch_decompositions(torch_enabled_decompositions)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
    overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)

    if overlap:
        raise AssertionError(
            f"Detected {overlap} registered in both torch_enabled_decompositions "
            "and torch_disabled_decompositions. Ensure all operator(s) are in "
            "at most one of the two sets."
        )


check_decomp_set_invariants()
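To see the curated table in effect outside the TRT pipeline, one can trace a small function with it. A minimal sketch, assuming ENABLED_TORCH_DECOMPOSITIONS from the module above is in scope:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# hardswish appears in torch_enabled_decompositions, so tracing with
# ENABLED_TORCH_DECOMPOSITIONS rewrites it into primitive ATen ops.
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.hardswish(x)

gm = make_fx(f, decomposition_table=ENABLED_TORCH_DECOMPOSITIONS)(torch.randn(4))
print(gm.graph)  # hardswish expanded into add/clamp/mul/div-style ops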
