Commit a3dd0c8

fix: Add support for general-purpose exclusion
- Add functionality for advanced exclusion of both function- and module-type nodes in Torch-TRT
- Add sample exclusion for the `torch.einsum` function, which can be accelerated as a single unit via TRT
- Add utilities and improve module- and function-level exclusion mechanisms
- Add test cases for the new exclusion mechanism
1 parent 80b898a commit a3dd0c8

9 files changed: +208 -67 lines changed

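At a high level, the change replaces the module-only registry with a general `SUBSTITUTION_REGISTRY` keyed by either an `nn.Module` subclass or a callable node target, exposed through the `register_substitution` decorator. A minimal sketch of registering a function-level exclusion, assuming a hypothetical custom operator `torch.ops.tensorrt.my_op` that has been defined and given a TRT converter elsewhere:

import torch

from torch_tensorrt.dynamo.backend.lowering import register_substitution


# Hypothetical example: exclude torch.nn.functional.gelu from decomposition by
# swapping it for a custom operator (torch.ops.tensorrt.my_op is a placeholder
# assumed to be defined via torch._custom_op and to have a TRT converter).
@register_substitution(torch.nn.functional.gelu, torch.ops.tensorrt.my_op)
def gelu_insertion_fn(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    _unused: None = None,
) -> torch.fx.Node:
    # Build a call_function node targeting the custom op, reusing the original arguments
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.my_op,
        args=node.args,
        kwargs=node.kwargs,
    )
    return new_node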

Diff for: py/torch_tensorrt/dynamo/backend/backends.py

+2-2
@@ -10,7 +10,7 @@
     get_decompositions,
 )
 from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
-    pre_aot_module_replacement,
+    pre_aot_substitutions,
 )
 from torch_tensorrt.dynamo.backend.lowering._partition import (
     partition,
@@ -70,7 +70,7 @@ def aot_torch_tensorrt_aten_backend(
     logger.debug("Pre-module replacement graph:\n" + str(gm.graph))
 
     # Perform Pre-AOT Lowering for Module-Level Replacement
-    gm = pre_aot_module_replacement(gm)
+    gm = pre_aot_substitutions(gm)
 
     logger.debug("Post-module replacement graph:\n" + str(gm.graph))
 
Diff for: py/torch_tensorrt/dynamo/backend/lowering/__init__.py

+3-3
@@ -2,8 +2,8 @@
     get_decompositions,
 )
 from ._pre_aot_lowering import (
-    MODULE_SUBSTITUTION_REGISTRY,
-    module_substitution,
+    SUBSTITUTION_REGISTRY,
+    register_substitution,
 )
 from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
-from .module_substitutions import *
+from .substitutions import *

Diff for: py/torch_tensorrt/dynamo/backend/lowering/_partition.py

+3-3
@@ -4,7 +4,7 @@
 import torch
 
 from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
-from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY
+from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
@@ -16,8 +16,8 @@
 logger = logging.getLogger(__name__)
 
 DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
-    _get_qualified_name(module.new_operator)
-    for module in MODULE_SUBSTITUTION_REGISTRY.values()
+    _get_qualified_name(to_replace.new_operator)
+    for to_replace in SUBSTITUTION_REGISTRY.values()
 )
 
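Since DEFAULT_SINGLE_NODE_PARTITIONS is now derived from the unified registry, every registered substitution, whether keyed by a module type or by a function, contributes its replacement operator as a standalone single-node partition. A small inspection sketch, assuming the substitutions package has been imported so the registry is populated:

from torch.fx.node import _get_qualified_name

from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY

# Each registry value is a Substitution carrying the replacement operator; the
# partitioner treats that operator's qualified name as its own single-node block.
for target, substitution in SUBSTITUTION_REGISTRY.items():
    print(target, "->", _get_qualified_name(substitution.new_operator))
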
Diff for: py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py

+67-54
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Type
+from typing import Any, Callable, Dict, Optional, Type, Union
 import torch
 import logging
 
@@ -8,59 +8,62 @@
 
 
 @dataclass(frozen=True)
-class ModuleReplacement:
+class Substitution:
     """Class to store key functionality for module replacement"""
 
     # torch.ops.___ name for replacement function for module
     new_operator: torch._ops.OpOverload
 
-    # Function taking a containing graph, a submodule, and a 'call_module' node and returning
-    # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
+    # Function taking a containing graph, a node, and optionally a submodule (if replacing a module)
+    # and returning a replacement node, with type 'call_function', or raising an Error if
+    # incompatibility is detected
     # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
     subgraph_insertion_fn: Callable[
-        [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node
+        [torch.fx.GraphModule, torch.fx.Node, Optional[torch.nn.Module]], torch.fx.Node
     ]
 
 
-# Dictionary mapping module to ModuleReplacement instance
-MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict()
+# Dictionary mapping module to Substitution instance
+SUBSTITUTION_REGISTRY: Dict[
+    Union[Type[torch.nn.Module], Callable], Substitution
+] = dict()
 
 
-def module_substitution(
-    module_to_replace: Type[torch.nn.Module],
+def register_substitution(
+    module_or_function_to_replace: Union[Type[torch.nn.Module], Callable],
     new_operator: torch._ops.OpOverload,
     enabled: bool = True,
 ) -> Callable[[Any], Any]:
     """Decorator to register subgraph insertion functions
 
     Args:
-        module_to_replace: nn.Module to replace
+        module_or_function_to_replace: nn.Module or node target Callable to replace
         new_operator: Custom torch operator to replace with
         enabled: Whether the substitution is enabled or disabled
     Returns:
         torch.fx.GraphModule
     """
 
-    def register_substitution(subgraph_insertion_fn):
+    def enable_substitution(subgraph_insertion_fn):
         """Function for use if substitution is enabled"""
-        module_replacement = ModuleReplacement(
+        replacement = Substitution(
             new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
         )
-        MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
+        SUBSTITUTION_REGISTRY[module_or_function_to_replace] = replacement
         return subgraph_insertion_fn
 
     def disable_substitution(subgraph_insertion_fn):
         """Function for use if substitution is disabled"""
         return subgraph_insertion_fn
 
-    return register_substitution if enabled else disable_substitution
+    return enable_substitution if enabled else disable_substitution
 
 
-def pre_aot_module_replacement(gm: torch.fx.GraphModule):
-    """Perform module-level graph replacement prior to AOT tracing
+def pre_aot_substitutions(gm: torch.fx.GraphModule):
+    """Perform graph substitutions prior to AOT tracing
 
     Args:
-        gm: FX GraphModule to perform module replacement on
+        gm: FX GraphModule to perform substitution on
     Returns:
         torch.fx.GraphModule
@@ -71,48 +74,58 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
 
     # Iterate over graph nodes, extracting module calls, to check for interceptions
     for n in gm.graph.nodes:
+        exists_in_registry = False
+        to_replace = None
+
         if n.op == "call_module":
-            # Extract submodule from graph
+            # Extract submodule from graph, validate in registry
             submodule = gm.get_submodule(n.target)
-
-            # If submodule is a member of the substitution registry, replace it
-            if type(submodule) in MODULE_SUBSTITUTION_REGISTRY:
-
-                try:
-                    replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)]
-                    op, insertion_fn = (
-                        replacement.new_operator,
-                        replacement.subgraph_insertion_fn,
-                    )
-                    logger.debug(
-                        f"Replacing module of type {type(submodule)} with {op}"
+            to_replace = type(submodule)
+            exists_in_registry = to_replace in SUBSTITUTION_REGISTRY
+        elif n.op == "call_function":
+            # Extract function from graph, validate in registry
+            to_replace = n.target
+            exists_in_registry = n.target in SUBSTITUTION_REGISTRY
+
+        # If submodule/function is a member of the substitution registry, replace it
+        if exists_in_registry:
+            try:
+                replacement = SUBSTITUTION_REGISTRY[to_replace]
+                op, insertion_fn = (
+                    replacement.new_operator,
+                    replacement.subgraph_insertion_fn,
+                )
+                logger.debug(f"Replacing node of type {to_replace} with {op}")
+
+                # Insert new node prior to older node
+                with gm.graph.inserting_before(n):
+                    new_node = insertion_fn(
+                        gm, n, submodule if n.op == "call_module" else None
                     )
 
-                    # Insert new node prior to older node
-                    with gm.graph.inserting_before(n):
-                        new_node = insertion_fn(gm, submodule, n)
-
-                    # If submodule is not a native torch.nn module, it must be manually excluded
-                    # from Dynamo tracing
-                    if not type(submodule).__module__.startswith("torch.nn"):
-                        torch._dynamo.allowed_functions._allowed_function_ids.add(
-                            id(type(submodule))
-                        )
-
-                    # Replace all original node uses and clean up graph
-                    n.replace_all_uses_with(new_node)
-                    gm.graph.eliminate_dead_code()
-                    gm.graph.lint()
-                    gm.recompile()
-
-                # A module replacement can fail in the event that the specific instance of the submodule cannot
-                # be replaced
-                except Exception:
-                    logger.debug(
-                        f"Encountered error while replacing {type(submodule)}",
-                        exc_info=True,
+                # If submodule is not a native torch.nn module, it must be manually excluded
+                # from Dynamo tracing
+                if n.op == "call_module" and not type(submodule).__module__.startswith(
+                    "torch.nn"
+                ):
+                    torch._dynamo.allowed_functions._allowed_function_ids.add(
+                        id(to_replace)
                     )
-                    continue
+
+                # Replace all original node uses and clean up graph
+                n.replace_all_uses_with(new_node)
+                gm.graph.eliminate_dead_code()
+                gm.graph.lint()
+                gm.recompile()
+
+            # A replacement can fail in the event that the specific instance of the submodule/function
+            # cannot be replaced
+            except Exception:
+                logger.debug(
+                    f"Encountered error while replacing {to_replace}",
+                    exc_info=True,
+                )
+                continue
 
     # Perform cleanup and recompilation before returning module
     gm.graph.eliminate_dead_code()
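To see the reworked pass end to end: `pre_aot_substitutions` now intercepts both `call_module` and `call_function` nodes, so a plain `torch.einsum` call can be rewritten to the custom operator before AOT tracing. A rough sketch of invoking the pass directly, assuming the lowering package has been imported so the bundled substitutions are registered:

import torch

from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import pre_aot_substitutions


class MiniModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.einsum("ij,ji->ij", x, y)


# Symbolically trace, then run the pre-AOT substitution pass; call_function nodes
# targeting torch.einsum should now target torch.ops.tensorrt.einsum instead.
gm = torch.fx.symbolic_trace(MiniModel())
gm = pre_aot_substitutions(gm)
print(gm.graph)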

Diff for: py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py renamed to py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py

+1
@@ -1 +1,2 @@
 from .maxpool1d import *
+from .einsum import *

Diff for: py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py

+79
@@ -0,0 +1,79 @@
+from typing import Dict, Tuple
+import torch
+from torch._custom_op.impl import custom_op
+from torch.fx.node import Argument, Target
+
+from torch_tensorrt.fx.converter_registry import tensorrt_converter
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from torch_tensorrt.dynamo.backend.lowering import register_substitution
+
+
+@custom_op(
+    qualname="tensorrt::einsum",
+    manual_schema="(str equation, Tensor[] tensors) -> Tensor",
+)
+def einsum(equation, tensors):
+    # Defines operator schema, name, namespace, and function header
+    ...
+
+
+@einsum.impl("cpu")
+@einsum.impl("cuda")
+def einsum_generic(
+    *args,
+    **kwargs,
+):
+    # Defines a generic implementation for AOT Autograd to use for shape analysis/propagation
+    return torch.einsum(
+        *args,
+        **kwargs,
+    )
+
+
+@tensorrt_converter(torch.ops.tensorrt.einsum.default)
+def aten_ops_einsum(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    # Defines converter replacing the default operator for this function
+    for input_trt in args[1]:
+        if not isinstance(input_trt, TRTTensor):
+            raise RuntimeError(f"Einsum received non-TRTTensor input: {input_trt}")
+
+    einsum_layer = network.add_einsum(inputs=args[1], equation=args[0])
+
+    set_layer_name(einsum_layer, target, name)
+    return einsum_layer.get_output(0)
+
+
+@register_substitution(torch.einsum, torch.ops.tensorrt.einsum)
+def einsum_insertion_fn(
+    gm: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    _unused: None = None,
+) -> torch.fx.Node:
+    equation = node.args[0]
+
+    # Ensure inputs is a list of (Tensor) arguments
+    if isinstance(node.args[1], (tuple, list)):
+        inputs = node.args[1]
+    else:
+        inputs = node.args[1:]
+
+    assert (
+        1 <= len(inputs) <= 2
+    ), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors"
+
+    # Ensure the input is formatted as an equation and list of Tensors, then build the new node
+    new_node = gm.graph.call_function(
+        torch.ops.tensorrt.einsum,
+        args=(equation, inputs),
+        kwargs=node.kwargs,
+    )
+
+    return new_node
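
For reference, once the lowering package has been imported the custom operator is reachable as `torch.ops.tensorrt.einsum`, and its generic implementation simply defers to `torch.einsum`. A small sanity-check sketch under that assumption:

import torch

import torch_tensorrt.dynamo.backend.lowering  # importing the package registers the substitutions

x = torch.rand(16, 16)
y = torch.rand(16, 16)

# The generic (cpu/cuda) implementation dispatches to torch.einsum, so outside of
# TRT compilation the custom op should match the eager reference.
out = torch.ops.tensorrt.einsum("ij,ji->ij", [x, y])
ref = torch.einsum("ij,ji->ij", x, y)
assert torch.allclose(out, ref)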

Diff for: py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py renamed to py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py

+5-3
@@ -7,7 +7,7 @@
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
-from torch_tensorrt.dynamo.backend.lowering import module_substitution
+from torch_tensorrt.dynamo.backend.lowering import register_substitution
 
 
 # This file serves as an example and a tutorial for excluding custom modules from
@@ -70,9 +70,11 @@ def maxpool1d_generic(
 #     "bias": bias,
 #     ...
 #
-@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
+@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
 def maxpool1d_insertion_fn(
-    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
+    gm: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    submodule: torch.nn.Module,
 ) -> torch.fx.Node:
     # Defines insertion function for new node
     new_node = gm.graph.call_function(
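
Note the updated signature order: insertion functions now receive the node second and the (now optional) submodule third, so module-based substitutions can still read attributes off the intercepted submodule. A generic, hypothetical sketch of that pattern (the operator and attribute below are placeholders, not part of this commit):

import torch

from torch_tensorrt.dynamo.backend.lowering import register_substitution


# Hypothetical module-level substitution: torch.ops.tensorrt.my_module_op is a
# placeholder custom operator assumed to be defined and converted elsewhere.
@register_substitution(torch.nn.Hardswish, torch.ops.tensorrt.my_module_op)
def my_module_insertion_fn(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    submodule: torch.nn.Module,
) -> torch.fx.Node:
    # Attributes of the intercepted submodule can be forwarded as kwargs
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.my_module_op,
        args=node.args,
        kwargs={"inplace": submodule.inplace},
    )
    return new_node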

Diff for: py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py

+46
@@ -51,5 +51,51 @@ def forward(self, x):
         )
 
 
+class TestEinsum(TestCase):
+    def test_pre_aot_lowering_einsum(self):
+        class Einsum(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.einsum("ij,ji->ij", x, y)
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = {torch.ops.tensorrt.einsum.default}
+
+        inputs = [
+            torch.rand(
+                16,
+                16,
+            ).cuda(),
+            torch.rand(
+                16,
+                16,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(Einsum())
+        _, expected_ops_unseen = lower_graph_testing(
+            fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
+        self.assertAlmostEqual(
+            max_diff, 0, f"Einsum TRT outputs don't match with the original model."
+        )
+
+
 if __name__ == "__main__":
     run_tests()
