Commit 54843d3

fix: Address review comments

- Fix typing issues, add dependencies to `setup.py`, and add qualified-name checking for the module registry
- Add a step-by-step tutorial to the sample module substitution, with detailed instructions for creating a new module substitution
- Update `custom_op` for the new Torch schema

1 parent b110e60 commit 54843d3

6 files changed: +84, -31 lines

Diff for: .circleci/config.yml

+1-1
@@ -258,7 +258,7 @@ commands:
           name: Set up python environment
           command: |
             pip3 install --upgrade pip
-            pip3 install wheel setuptools pyyaml
+            pip3 install wheel setuptools
             pip3 install nvidia-pyindex
             pip3 install tabulate
             pip3 install tensorrt==<< parameters.trt-version-long >> nvidia-cudnn-cu11==<< parameters.cudnn-version-long >>
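
The dropped `pyyaml` pin matches the commit message's note that dependencies were moved into `setup.py`. The `setup.py` diff itself is not shown in this extract; as a rough, hypothetical sketch of what such a declaration looks like (names and contents assumed, not taken from the commit):

# Hypothetical sketch -- the actual setup.py change is not visible on this page.
from setuptools import setup

setup(
    name="torch_tensorrt",
    install_requires=[
        "pyyaml",  # previously installed ad hoc in the CI environment step
        # ... remaining runtime dependencies
    ],
)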

Diff for: py/torch_tensorrt/__init__.py

+1-1
@@ -94,7 +94,7 @@ def _find_lib(name, paths):
 
 from torch_tensorrt import fx
 
-if version.parse(torch.__version__) >= version.parse("2.dev"):
+if version.parse(torch.__version__) >= version.parse("2.1.dev"):
     from torch_tensorrt import dynamo
     from torch_tensorrt.dynamo import backend
 
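
The version-gate fix matters because of how `packaging.version` orders pre-releases: "2.dev" sorts before every stable 2.x release, so the old bound let stable torch 2.0.x trigger the dynamo import. A quick standalone check (illustrative, not part of the commit):

from packaging import version

# "2.dev" parses as pre-release 2.dev0, which sorts *before* 2.0, so any
# stable torch 2.0.x satisfied the old check:
assert version.parse("2.0.1") >= version.parse("2.dev")

# "2.1.dev" (i.e. 2.1.dev0) sorts after every 2.0.x release, so the new
# bound admits only torch 2.1 pre-releases and newer:
assert not version.parse("2.0.1") >= version.parse("2.1.dev")
assert version.parse("2.1.0.dev20230601") >= version.parse("2.1.dev")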

Diff for: py/torch_tensorrt/dynamo/backend/backends.py

+1-1
@@ -69,7 +69,7 @@ def aot_torch_tensorrt_aten_backend(
 
     logger.debug("Pre-module replacement graph:\n" + str(gm.graph))
 
-    # Enable Pre-AOT Lowering for Module-Level Replacement
+    # Perform Pre-AOT Lowering for Module-Level Replacement
     gm = pre_aot_module_replacement(gm)
 
     logger.debug("Post-module replacement graph:\n" + str(gm.graph))

Diff for: py/torch_tensorrt/dynamo/backend/lowering/_partition.py

+1-1
@@ -16,7 +16,7 @@
 logger = logging.getLogger(__name__)
 
 DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
-    "torch.ops." + str(module.new_operator)
+    _get_qualified_name(module.new_operator)
     for module in MODULE_SUBSTITUTION_REGISTRY.values()
 )
 
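
For reference, `_get_qualified_name` (imported from `torch.fx.node`) resolves a callable, including op overloads, to a fully qualified name carrying the `torch.ops` prefix that the replaced line concatenated by hand. A minimal sketch using a built-in op (the exact string format can vary across torch versions):

import torch
from torch.fx.node import _get_qualified_name

# Yields a fully qualified name such as "torch.ops.aten.add.Tensor"
# (exact formatting may differ between torch versions).
print(_get_qualified_name(torch.ops.aten.add.Tensor))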

Diff for: py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py

+5-3
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Type
 import torch
 import logging
 
@@ -23,11 +23,11 @@ class ModuleReplacement:
 
 
 # Dictionary mapping module to ModuleReplacement instance
-MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict()
+MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict()
 
 
 def module_substitution(
-    module_to_replace: torch.nn.Module,
+    module_to_replace: Type[torch.nn.Module],
     new_operator: torch._ops.OpOverload,
     enabled: bool = True,
 ) -> Callable[[Any], Any]:
@@ -102,6 +102,7 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
             # Replace all original node uses and clean up graph
             n.replace_all_uses_with(new_node)
             gm.graph.eliminate_dead_code()
+            gm.graph.lint()
             gm.recompile()
 
     # A module replacement can fail in the event that the specific instance of the submodule cannot
@@ -115,5 +116,6 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
 
     # Perform cleanup and recompilation before returning module
     gm.graph.eliminate_dead_code()
+    gm.graph.lint()
     gm.recompile()
     return gm
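
Two notes on this hunk. The added `gm.graph.lint()` calls make the pass fail fast if a rewrite leaves the graph malformed: lint checks topological ordering and that node arguments reference nodes belonging to the graph. And the switch to `Type[torch.nn.Module]` reflects that the registry is keyed by module classes, not instances; a minimal standalone sketch of the pattern (registry name and value type hypothetical):

from typing import Dict, Type
import torch

# Hypothetical mini-registry mirroring the keying pattern above
REGISTRY: Dict[Type[torch.nn.Module], str] = {}
REGISTRY[torch.nn.MaxPool1d] = "tensorrt::maxpool1d"  # registered by class

pool = torch.nn.MaxPool1d(kernel_size=3)
assert type(pool) in REGISTRY  # instances are looked up via type(...)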
Diff for: [maxpool1d module-substitution example; file path not shown in this extract]

+75-24

@@ -1,6 +1,6 @@
 from typing import Dict, Tuple
 import torch
-from torch._custom_op import custom_op
+from torch._custom_op.impl import custom_op
 from torch.fx.node import Argument, Target
 
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
@@ -10,30 +10,94 @@
 from torch_tensorrt.dynamo.backend.lowering import module_substitution
 
 
+# This file serves as an example and a tutorial for excluding custom modules from
+# torch.compile tracing. Each required step is labeled with a number indicating the
+# preferable implementation order.
+
+
+# 1. The Placeholder
+#
+# Specify the schema and namespace of the operator, as well as a placeholder function
+# representing the schema. The schema should be in torch JIT syntax, indicating input and output
+# types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op
+# Then, create a placeholder function with no operations, but having the same schema and naming as that
+# used in the decorator
 @custom_op(
-    "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
-    ns="tensorrt",
+    qualname="tensorrt::maxpool1d",
+    manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
 )
-def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
+def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):
     # Defines operator schema, name, namespace, and function header
     ...
 
 
+# 2. The Generic Implementation
+#
+# Define the default implementation of the operator in torch syntax. This is used for autograd
+# and other tracing functionality. Generally, the torch.nn.functional analog of the operator to replace
+# is desirable. If the operator to replace is a custom module you've written, then add its Torch
+# implementation here. Note that the function header to the generic function can have specific arguments
+# as in the above placeholder
 @maxpool1d.impl("cpu")
 @maxpool1d.impl("cuda")
 def maxpool1d_generic(
     *args,
     **kwargs,
 ):
-    # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation
+    # Defines an implementation for AOT Autograd to use for shape analysis/propagation
     return torch.nn.functional.max_pool1d(
         *args,
         **kwargs,
     )
 
 
+# 3. The Module Substitution Function
+#
+# Define a function which can intercept a node of the kind to be replaced, extract
+# the relevant data from that node/submodule, and then re-package the information
+# for use by an accelerated implementation (to be implemented in step 4). This function
+# should use the operator defined in step 1 (for example torch.ops.tensorrt.maxpool1d).
+# It should refactor the args and kwargs as is needed by the accelerated implementation.
+#
+# If the submodule has weights or other Tensor fields which the accelerated implementation
+# needs, the function should insert the necessary nodes to access those weights. For example,
+# if the weight Tensor of a submodule is needed, one could write:
+#
+#     weights = gm.graph.get_attr(n.target + ".weight", torch.Tensor)
+#     bias = gm.graph.get_attr(n.target + ".bias", torch.Tensor)
+#     ...
+#     kwargs={"weight": weights,
+#             "bias": bias,
+#             ...
+#
+@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
+def maxpool1d_insertion_fn(
+    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
+) -> torch.fx.Node:
+    # Defines insertion function for new node
+    new_node = gm.graph.call_function(
+        torch.ops.tensorrt.maxpool1d,
+        args=node.args,
+        kwargs={
+            "kernel_size": submodule.kernel_size,
+            "stride": submodule.stride,
+            "padding": submodule.padding,
+            "dilation": submodule.dilation,
+            "ceil_mode": submodule.ceil_mode,
+        },
+    )
+
+    return new_node
+
+
+# 4. The Accelerated Implementation
+#
+# Define an accelerated implementation of the operator, and register it as necessary.
+# This accelerated implementation should consume the args/kwargs specified in step 3.
+# One should expect that torch.compile will compress all kwargs into the args field in
+# the order specified in the schema written in step 1.
 @tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
-def aten_ops_maxpool1d(
+def tensorrt_maxpool1d(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -55,21 +119,8 @@ def aten_ops_maxpool1d(
     )
 
 
-@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
-def maxpool1d_insertion_fn(
-    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
-) -> torch.fx.Node:
-    # Defines insertion function for new node
-    new_node = gm.graph.call_function(
-        torch.ops.tensorrt.maxpool1d,
-        args=node.args,
-        kwargs={
-            "kernel_size": submodule.kernel_size,
-            "stride": submodule.stride,
-            "padding": submodule.padding,
-            "dilation": submodule.dilation,
-            "ceil_mode": submodule.ceil_mode,
-        },
-    )
-
-    return new_node
+# 5. Add Imports
+#
+# Add your accelerated module file to the __init__.py in this directory, to ensure
+# all registrations are run. For instance, if the new module file is called new_mod.py,
+# one should add `from .new_mod import *` to the __init__.py
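
Once all five steps are in place, a model containing `torch.nn.MaxPool1d` compiled through the torch_tensorrt dynamo backend should have the submodule swapped for `torch.ops.tensorrt.maxpool1d` before AOT tracing. A rough usage sketch (backend name assumed to be the registered "torch_tensorrt" alias):

import torch

class Pooler(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.MaxPool1d(kernel_size=3)

    def forward(self, x):
        return self.pool(x)

model = Pooler().eval().cuda()
x = torch.randn(4, 8, 32).cuda()

# Pre-AOT lowering replaces the MaxPool1d submodule with the custom op,
# which the converter registered in step 4 then implements in TensorRT.
compiled = torch.compile(model, backend="torch_tensorrt")
out = compiled(x)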
