Commit c9f06fc
feat: Add support for general-purpose function acceleration in Dynamo [6 / x] (#1980)
1 parent e642b62 commit c9f06fc

File tree: 9 files changed, +209 -67 lines changed


Diff for: py/torch_tensorrt/dynamo/backend/backends.py

+2-2
@@ -9,7 +9,7 @@
     get_decompositions,
 )
 from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
-    pre_aot_module_replacement,
+    pre_aot_substitutions,
 )
 from torch_tensorrt.dynamo.backend.lowering._partition import (
     partition,
@@ -45,7 +45,7 @@ def aot_torch_tensorrt_aten_backend(
     )

     # Perform Pre-AOT Lowering for Module-Level Replacement
-    gm = pre_aot_module_replacement(gm)
+    gm = pre_aot_substitutions(gm)

     # Invoke AOTAutograd to translate operators to aten
     return aot_module_simplified(
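
Note: a minimal usage sketch (not part of this diff) of the path that exercises the renamed call. It assumes the Dynamo backend registered in backends.py is exposed under the name "torch_tensorrt" and that a CUDA device is available; compiling through Dynamo is what invokes pre_aot_substitutions before AOTAutograd lowers the graph to aten.

import torch
import torch_tensorrt  # noqa: F401  (importing registers the Dynamo backend)

model = torch.nn.Sequential(
    torch.nn.Conv1d(3, 8, kernel_size=3),
    torch.nn.MaxPool1d(kernel_size=2),
).eval().cuda()

# Backend name is an assumption; see the register_backend call in backends.py
compiled = torch.compile(model, backend="torch_tensorrt")
out = compiled(torch.randn(4, 3, 64).cuda())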

Diff for: py/torch_tensorrt/dynamo/backend/lowering/__init__.py

+3-3
@@ -2,8 +2,8 @@
     get_decompositions,
 )
 from ._pre_aot_lowering import (
-    MODULE_SUBSTITUTION_REGISTRY,
-    module_substitution,
+    SUBSTITUTION_REGISTRY,
+    register_substitution,
 )
 from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
-from .module_substitutions import *
+from .substitutions import *

Diff for: py/torch_tensorrt/dynamo/backend/lowering/_partition.py

+3-3
@@ -4,7 +4,7 @@
 import torch

 from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
-from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY
+from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
@@ -16,8 +16,8 @@
 logger = logging.getLogger(__name__)

 DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
-    _get_qualified_name(module.new_operator)
-    for module in MODULE_SUBSTITUTION_REGISTRY.values()
+    _get_qualified_name(to_replace.new_operator)
+    for to_replace in SUBSTITUTION_REGISTRY.values()
 )
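
Note: with this change the set of single-node partitions is derived from every entry in SUBSTITUTION_REGISTRY, not only module substitutions. A minimal sketch (not part of this diff) of the same comprehension evaluated by hand once the bundled substitutions have been imported; the exact qualified-name strings depend on _get_qualified_name and the registered operators.

from torch.fx.node import _get_qualified_name
from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY

# Keys are nn.Module subclasses (e.g. torch.nn.MaxPool1d) or callables (e.g. torch.einsum);
# each value's new_operator is the custom op that should stay in its own partition.
single_node_ops = {
    _get_qualified_name(sub.new_operator) for sub in SUBSTITUTION_REGISTRY.values()
}
print(single_node_ops)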

Diff for: py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py

+67-54
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Type
+from typing import Any, Callable, Dict, Optional, Type, Union
 import torch
 import logging

@@ -8,59 +8,62 @@


 @dataclass(frozen=True)
-class ModuleReplacement:
+class Substitution:
     """Class to store key functionality for module replacement"""

     # torch.ops.___ name for replacement function for module
     new_operator: torch._ops.OpOverload

-    # Function taking a containing graph, a submodule, and a 'call_module' node and returning
-    # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
+    # Function taking a containing graph, a node, and optionally a submodule (if replacing a module)
+    # and returning a replacement node, with type 'call_function', or raising an Error if
+    # incompatibility is detected
     # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
     subgraph_insertion_fn: Callable[
-        [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node
+        [torch.fx.GraphModule, torch.fx.Node, Optional[torch.nn.Module]], torch.fx.Node
     ]


-# Dictionary mapping module to ModuleReplacement instance
-MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict()
+# Dictionary mapping module to Substitution instance
+SUBSTITUTION_REGISTRY: Dict[
+    Union[Type[torch.nn.Module], Callable], Substitution
+] = dict()


-def module_substitution(
-    module_to_replace: Type[torch.nn.Module],
+def register_substitution(
+    module_or_function_to_replace: Union[Type[torch.nn.Module], Callable],
     new_operator: torch._ops.OpOverload,
     enabled: bool = True,
 ) -> Callable[[Any], Any]:
     """Decorator to register subgraph insertion functions

     Args:
-        module_to_replace: nn.Module to replace
+        module_or_function_to_replace: nn.Module or node target Callable to replace
         new_operator: Custom torch operator to replace with
         enabled: Whether the substitution is enabled or disabled
     Returns:
         torch.fx.GraphModule
     """

-    def register_substitution(subgraph_insertion_fn):
+    def enable_substitution(subgraph_insertion_fn):
         """Function for use if substitution is enabled"""
-        module_replacement = ModuleReplacement(
+        replacement = Substitution(
             new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
         )
-        MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
+        SUBSTITUTION_REGISTRY[module_or_function_to_replace] = replacement
         return subgraph_insertion_fn

     def disable_substitution(subgraph_insertion_fn):
         """Function for use if substitution is disabled"""
         return subgraph_insertion_fn

-    return register_substitution if enabled else disable_substitution
+    return enable_substitution if enabled else disable_substitution


-def pre_aot_module_replacement(gm: torch.fx.GraphModule):
-    """Perform module-level graph replacement prior to AOT tracing
+def pre_aot_substitutions(gm: torch.fx.GraphModule):
+    """Perform graph substitutions prior to AOT tracing

     Args:
-        gm: FX GraphModule to perform module replacement on
+        gm: FX GraphModule to perform substitution on
     Returns:
         torch.fx.GraphModule
@@ -73,48 +76,58 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):

     # Iterate over graph nodes, extracting module calls, to check for interceptions
     for n in gm.graph.nodes:
+        exists_in_registry = False
+        to_replace = None
+
         if n.op == "call_module":
-            # Extract submodule from graph
+            # Extract submodule from graph, validate in registry
             submodule = gm.get_submodule(n.target)
-
-            # If submodule is a member of the substitution registry, replace it
-            if type(submodule) in MODULE_SUBSTITUTION_REGISTRY:
-
-                try:
-                    replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)]
-                    op, insertion_fn = (
-                        replacement.new_operator,
-                        replacement.subgraph_insertion_fn,
-                    )
-                    logger.debug(
-                        f"Replacing module of type {type(submodule)} with {op}"
+            to_replace = type(submodule)
+            exists_in_registry = to_replace in SUBSTITUTION_REGISTRY
+        elif n.op == "call_function":
+            # Extract function from graph, validate in registry
+            to_replace = n.target
+            exists_in_registry = n.target in SUBSTITUTION_REGISTRY
+
+        # If submodule/function is a member of the substitution registry, replace it
+        if exists_in_registry:
+            try:
+                replacement = SUBSTITUTION_REGISTRY[to_replace]
+                op, insertion_fn = (
+                    replacement.new_operator,
+                    replacement.subgraph_insertion_fn,
+                )
+                logger.debug(f"Replacing node of type {to_replace} with {op}")
+
+                # Insert new node prior to older node
+                with gm.graph.inserting_before(n):
+                    new_node = insertion_fn(
+                        gm, n, submodule if n.op == "call_module" else None
                     )

-                    # Insert new node prior to older node
-                    with gm.graph.inserting_before(n):
-                        new_node = insertion_fn(gm, submodule, n)
-
-                    # If submodule is not a native torch.nn module, it must be manually excluded
-                    # from Dynamo tracing
-                    if not type(submodule).__module__.startswith("torch.nn"):
-                        torch._dynamo.allowed_functions._allowed_function_ids.add(
-                            id(type(submodule))
-                        )
-
-                    # Replace all original node uses and clean up graph
-                    n.replace_all_uses_with(new_node)
-                    gm.graph.eliminate_dead_code()
-                    gm.graph.lint()
-                    gm.recompile()
-
-                # A module replacement can fail in the event that the specific instance of the submodule cannot
-                # be replaced
-                except Exception:
-                    logger.debug(
-                        f"Encountered error while replacing {type(submodule)}",
-                        exc_info=True,
+                # If submodule is not a native torch.nn module, it must be manually excluded
+                # from Dynamo tracing
+                if n.op == "call_module" and not type(submodule).__module__.startswith(
+                    "torch.nn"
+                ):
+                    torch._dynamo.allowed_functions._allowed_function_ids.add(
+                        id(to_replace)
                     )
-                    continue
+
+                # Replace all original node uses and clean up graph
+                n.replace_all_uses_with(new_node)
+                gm.graph.eliminate_dead_code()
+                gm.graph.lint()
+                gm.recompile()
+
+            # A replacement can fail in the event that the specific instance of the submodule/function
+            # cannot be replaced
+            except Exception:
+                logger.debug(
+                    f"Encountered error while replacing {to_replace}",
+                    exc_info=True,
                )
+                continue

     # Perform cleanup and recompilation before returning module
     gm.graph.eliminate_dead_code()
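
Note: a minimal sketch (not part of this diff) of how the generalized decorator can now key the registry on either an nn.Module subclass or a plain callable. The operator torch.ops.tensorrt.my_gelu is hypothetical and assumed to have been declared elsewhere via @custom_op; the insertion function follows the new (gm, node, submodule) argument order.

from typing import Optional

import torch
from torch_tensorrt.dynamo.backend.lowering import register_substitution

# Hypothetical custom operator; both registrations reuse the same insertion function.
@register_substitution(torch.nn.GELU, torch.ops.tensorrt.my_gelu)             # module-type key
@register_substitution(torch.nn.functional.gelu, torch.ops.tensorrt.my_gelu)  # callable key
def gelu_insertion_fn(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    submodule: Optional[torch.nn.Module] = None,
) -> torch.fx.Node:
    # Emit a call_function node targeting the custom operator with the original node's
    # arguments; pre_aot_substitutions replaces uses of the old node and cleans up the graph.
    return gm.graph.call_function(
        torch.ops.tensorrt.my_gelu, args=node.args, kwargs=node.kwargs
    )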

Diff for: py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py

+1

@@ -1 +1,2 @@
 from .maxpool1d import *
+from .einsum import *

Diff for: py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py

+80

@@ -0,0 +1,80 @@
+from typing import Dict, Tuple
+import torch
+from torch._custom_op.impl import custom_op
+from torch.fx.node import Argument, Target
+
+from torch_tensorrt.fx.converter_registry import tensorrt_converter
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from torch_tensorrt.dynamo.backend.lowering import register_substitution
+
+
+@custom_op(
+    qualname="tensorrt::einsum",
+    manual_schema="(str equation, Tensor[] tensors) -> Tensor",
+)
+def einsum(equation, tensors):
+    # Defines operator schema, name, namespace, and function header
+    ...
+
+
+@einsum.impl("cpu")
+@einsum.impl("cuda")
+@einsum.impl_abstract()
+def einsum_generic(
+    *args,
+    **kwargs,
+):
+    # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation
+    return torch.einsum(
+        *args,
+        **kwargs,
+    )
+
+
+@tensorrt_converter(torch.ops.tensorrt.einsum.default)
+def aten_ops_einsum(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    # Defines converter replacing the default operator for this function
+    for input_trt in args[1]:
+        if not isinstance(input_trt, TRTTensor):
+            raise RuntimeError(f"Einsum received non-TRTTensor input: {input_trt}")
+
+    einsum_layer = network.add_einsum(inputs=args[1], equation=args[0])
+
+    set_layer_name(einsum_layer, target, name)
+    return einsum_layer.get_output(0)
+
+
+@register_substitution(torch.einsum, torch.ops.tensorrt.einsum)
+def einsum_insertion_fn(
+    gm: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    _unused: None = None,
+) -> torch.fx.Node:
+    equation = node.args[0]
+
+    # Ensure inputs is a list of (Tensor) arguments
+    if isinstance(node.args[1], (tuple, list)):
+        inputs = node.args[1]
+    else:
+        inputs = node.args[1:]
+
+    assert (
+        1 <= len(inputs) <= 2
+    ), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors"
+
+    # Ensure the input is formatted as an equation and
+    new_node = gm.graph.call_function(
+        torch.ops.tensorrt.einsum,
+        args=(equation, inputs),
+        kwargs=node.kwargs,
+    )
+
+    return new_node
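
Note: a minimal usage sketch (not part of this diff) exercising the new function substitution directly. It assumes that importing the lowering package registers the bundled substitutions (its __init__ imports substitutions/*), so the torch.einsum call in the traced graph should be rewritten to target torch.ops.tensorrt.einsum. The module below is illustrative only.

import torch
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import pre_aot_substitutions

class EinsumMatMul(torch.nn.Module):  # illustrative module, not from this commit
    def forward(self, x, y):
        return torch.einsum("ij,jk->ik", x, y)

gm = torch.fx.symbolic_trace(EinsumMatMul())
gm = pre_aot_substitutions(gm)

# Expect the remaining call_function node to target the registered custom operator.
print([n.target for n in gm.graph.nodes if n.op == "call_function"])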

Diff for: py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py renamed to py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py

+5-3
@@ -7,7 +7,7 @@
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

-from torch_tensorrt.dynamo.backend.lowering import module_substitution
+from torch_tensorrt.dynamo.backend.lowering import register_substitution


 # This file serves as an example and a tutorial for excluding custom modules from
@@ -71,9 +71,11 @@ def maxpool1d_generic(
 #     "bias": bias,
 #     ...
 #
-@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
+@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
 def maxpool1d_insertion_fn(
-    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
+    gm: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    submodule: torch.nn.Module,
 ) -> torch.fx.Node:
     # Defines insertion function for new node
     new_node = gm.graph.call_function(
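
Note: the argument reorder above matters because pre_aot_substitutions now calls every insertion function as insertion_fn(gm, node, submodule-or-None). A minimal sketch (not part of this diff) checking that the module-level path still works after the change; the module below is illustrative only.

import torch
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import pre_aot_substitutions

class Pool(torch.nn.Module):  # illustrative module, not from this commit
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.MaxPool1d(kernel_size=2)

    def forward(self, x):
        return self.pool(x)

gm = torch.fx.symbolic_trace(Pool())
gm = pre_aot_substitutions(gm)

# The call_module node for nn.MaxPool1d is expected to be replaced by a call to
# torch.ops.tensorrt.maxpool1d.
print(gm.graph)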

Diff for: py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py

+46
@@ -51,5 +51,51 @@ def forward(self, x):
         )


+class TestEinsum(TestCase):
+    def test_pre_aot_lowering_einsum(self):
+        class Einsum(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.einsum("ij,ji->ij", x, y)
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = {torch.ops.tensorrt.einsum.default}
+
+        inputs = [
+            torch.rand(
+                16,
+                16,
+            ).cuda(),
+            torch.rand(
+                16,
+                16,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(Einsum())
+        _, expected_ops_unseen = lower_graph_testing(
+            fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
+        self.assertAlmostEqual(
+            max_diff, 0, f"Einsum TRT outputs don't match with the original model."
+        )
+
+
 if __name__ == "__main__":
     run_tests()
