from typing import Dict, Tuple
import torch
from torch._custom_op.impl import custom_op
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.dynamo.backend.lowering import module_substitution


# This file serves as an example and a tutorial for excluding custom modules from
# torch.compile tracing. Each required step is labeled with a number indicating the
# preferable implementation order.


# 1. The Placeholder
#
# Specify the schema and namespace of the operator, as well as a placeholder function
# representing the schema. The schema should be written in torch JIT syntax, indicating
# input and output types. The namespace, such as tensorrt, will cause the op to be
# registered as torch.ops.tensorrt.your_op. Then, create a placeholder function with no
# operations, but with the same schema and naming as used in the decorator.
@custom_op(
    qualname="tensorrt::maxpool1d",
    manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
)
def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):
    # Defines operator schema, name, namespace, and function header
    ...


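# A quick aside (illustrative, not part of the original tutorial): with the decorator
# above, the op is registered under the "tensorrt" namespace and is addressable as
# torch.ops.tensorrt.maxpool1d. Calling it before step 2 registers a device
# implementation will raise, but the registration itself can be checked:
#
#     op = torch.ops.tensorrt.maxpool1d
#     print(op)  # e.g. <OpOverloadPacket(op='tensorrt.maxpool1d')>

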
# 2. The Generic Implementation
#
# Define the default implementation of the operator in torch syntax. This is used for
# autograd and other tracing functionality. Generally, the torch.nn.functional analog
# of the operator to replace is desirable. If the operator to replace is a custom module
# you've written, then add its Torch implementation here. Note that the function header
# of the generic function can have specific arguments, as in the placeholder above.
@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
def maxpool1d_generic(
    *args,
    **kwargs,
):
    # Defines an implementation for AOT Autograd to use for shape analysis/propagation
    return torch.nn.functional.max_pool1d(
        *args,
        **kwargs,
    )


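# Quick sanity sketch (illustrative, not part of the original file): on CPU, calls to
# the registered op now dispatch to maxpool1d_generic, so results should match the
# functional analog:
#
#     x = torch.randn(2, 4, 16)
#     expected = torch.nn.functional.max_pool1d(x, kernel_size=3, stride=1)
#     actual = torch.ops.tensorrt.maxpool1d(x, [3], [1], [0], [1], False)
#     assert torch.equal(actual, expected)

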
# 3. The Module Substitution Function
#
# Define a function which can intercept a node of the kind to be replaced, extract
# the relevant data from that node/submodule, and then re-package the information
# for use by an accelerated implementation (to be implemented in step 4). This function
# should use the operator defined in step 1 (for example torch.ops.tensorrt.maxpool1d).
# It should refactor the args and kwargs as needed by the accelerated implementation.
#
# If the submodule has weights or other Tensor fields which the accelerated implementation
# needs, the function should insert the necessary nodes to access those weights. For example,
# if the weight Tensor of a submodule is needed, one could write:
#
#     weights = gm.graph.get_attr(node.target + ".weight", torch.Tensor)
#     bias = gm.graph.get_attr(node.target + ".bias", torch.Tensor)
#     ...
#     kwargs={"weight": weights,
#             "bias": bias,
#             ...
#
@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    # Defines insertion function for new node
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs={
            "kernel_size": submodule.kernel_size,
            "stride": submodule.stride,
            "padding": submodule.padding,
            "dilation": submodule.dilation,
            "ceil_mode": submodule.ceil_mode,
        },
    )

    return new_node


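# Illustrative before/after (not part of the original file): the substitution swaps a
# call_module node for the nn.MaxPool1d submodule with a call_function node targeting
# the custom op, so a traced graph changes roughly as follows:
#
#     Before:  %pool_out = call_module[target=pool](args = (%x,), kwargs = {})
#     After:   %pool_out = call_function[target=tensorrt.maxpool1d](
#                  args = (%x,),
#                  kwargs = {kernel_size: 3, stride: 1, padding: 0, dilation: 1,
#                            ceil_mode: False})

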
# 4. The Accelerated Implementation
#
# Define an accelerated implementation of the operator and register it as necessary.
# This accelerated implementation should consume the args/kwargs specified in step 3.
# One should expect that torch.compile will compress all kwargs into the args field,
# in the order specified in the schema written in step 1.
@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def tensorrt_maxpool1d(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
    # Defines the accelerated implementation, consuming the args/kwargs packaged by
    # the insertion function in step 3 (layer construction omitted in this excerpt)
    ...


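# Sketch of a possible converter body (an assumption for illustration, not the original
# code, which is omitted in this excerpt): since torch.compile may pass everything
# positionally in the schema order from step 1, a converter would typically begin by
# unpacking args before constructing the corresponding TensorRT pooling layer:
#
#     x, kernel_size, stride, padding, dilation, ceil_mode = args
#     # ... build an ITensor result via network.add_pooling_nd or similar ...

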
# 5. Add Imports
#
# Add your accelerated module file to the __init__.py in this directory to ensure
# all registrations are run. For instance, if the new module file is called new_mod.py,
# one should add `from .new_mod import *` to the __init__.py.
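#
# End-to-end sketch (illustrative, not part of the original file): with all five steps
# in place, compiling a model containing nn.MaxPool1d through the Dynamo backend keeps
# the pooling module out of torch.compile tracing and routes it through the converter
# in step 4:
#
#     model = torch.nn.Sequential(torch.nn.Conv1d(3, 8, 3), torch.nn.MaxPool1d(2)).cuda()
#     compiled = torch.compile(model, backend="torch_tensorrt")
#     out = compiled(torch.randn(1, 3, 32).cuda())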