1
1
from dataclasses import dataclass
2
- from typing import Any , Callable , Dict , Type
2
+ from typing import Any , Callable , Dict , Optional , Type , Union
3
3
import torch
4
4
import logging
5
5
8
8
9
9
10
10
@dataclass (frozen = True )
11
- class ModuleReplacement :
11
+ class Substitution :
12
12
"""Class to store key functionality for module replacement"""
13
13
14
14
# torch.ops.___ name for replacement function for module
15
15
new_operator : torch ._ops .OpOverload
16
16
17
- # Function taking a containing graph, a submodule, and a 'call_module' node and returning
18
- # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
17
+ # Function taking a containing graph, a node, and optionally a submodule (if replacing a module)
18
+ # and returning a replacement node, with type 'call_function', or raising an Error if
19
+ # incompatibility is detected
19
20
# Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
20
21
subgraph_insertion_fn : Callable [
21
- [torch .fx .GraphModule , torch .nn . Module , torch .fx . Node ], torch .fx .Node
22
+ [torch .fx .GraphModule , torch .fx . Node , Optional [ torch .nn . Module ] ], torch .fx .Node
22
23
]
23
24
24
25
25
- # Dictionary mapping module to ModuleReplacement instance
26
- MODULE_SUBSTITUTION_REGISTRY : Dict [Type [torch .nn .Module ], ModuleReplacement ] = dict ()
26
+ # Dictionary mapping module to Substitution instance
27
+ SUBSTITUTION_REGISTRY : Dict [
28
+ Union [Type [torch .nn .Module ], Callable ], Substitution
29
+ ] = dict ()
27
30
28
31
29
- def module_substitution (
30
- module_to_replace : Type [torch .nn .Module ],
32
+ def register_substitution (
33
+ module_or_function_to_replace : Union [ Type [torch .nn .Module ], Callable ],
31
34
new_operator : torch ._ops .OpOverload ,
32
35
enabled : bool = True ,
33
36
) -> Callable [[Any ], Any ]:
34
37
"""Decorator to register subgraph insertion functions
35
38
36
39
Args:
37
- module_to_replace : nn.Module to replace
40
+ module_or_function_to_replace : nn.Module or node target Callable to replace
38
41
new_operator: Custom torch operator to replace with
39
42
enabled: Whether the substitution is enabled or disabled
40
43
Returns:
41
44
torch.fx.GraphModule
42
45
"""
43
46
44
- def register_substitution (subgraph_insertion_fn ):
47
+ def enable_substitution (subgraph_insertion_fn ):
45
48
"""Function for use if substitution is enabled"""
46
- module_replacement = ModuleReplacement (
49
+ replacement = Substitution (
47
50
new_operator = new_operator , subgraph_insertion_fn = subgraph_insertion_fn
48
51
)
49
- MODULE_SUBSTITUTION_REGISTRY [ module_to_replace ] = module_replacement
52
+ SUBSTITUTION_REGISTRY [ module_or_function_to_replace ] = replacement
50
53
return subgraph_insertion_fn
51
54
52
55
def disable_substitution (subgraph_insertion_fn ):
53
56
"""Function for use if substitution is disabled"""
54
57
return subgraph_insertion_fn
55
58
56
- return register_substitution if enabled else disable_substitution
59
+ return enable_substitution if enabled else disable_substitution
57
60
58
61
59
- def pre_aot_module_replacement (gm : torch .fx .GraphModule ):
60
- """Perform module-level graph replacement prior to AOT tracing
62
+ def pre_aot_substitutions (gm : torch .fx .GraphModule ):
63
+ """Perform graph substitutions prior to AOT tracing
61
64
62
65
Args:
63
- gm: FX GraphModule to perform module replacement on
66
+ gm: FX GraphModule to perform substitution on
64
67
Returns:
65
68
torch.fx.GraphModule
66
69
@@ -73,48 +76,58 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
73
76
74
77
# Iterate over graph nodes, extracting module calls, to check for interceptions
75
78
for n in gm .graph .nodes :
79
+ exists_in_registry = False
80
+ to_replace = None
81
+
76
82
if n .op == "call_module" :
77
- # Extract submodule from graph
83
+ # Extract submodule from graph, validate in registry
78
84
submodule = gm .get_submodule (n .target )
79
-
80
- # If submodule is a member of the substitution registry, replace it
81
- if type (submodule ) in MODULE_SUBSTITUTION_REGISTRY :
82
-
83
- try :
84
- replacement = MODULE_SUBSTITUTION_REGISTRY [type (submodule )]
85
- op , insertion_fn = (
86
- replacement .new_operator ,
87
- replacement .subgraph_insertion_fn ,
88
- )
89
- logger .debug (
90
- f"Replacing module of type { type (submodule )} with { op } "
85
+ to_replace = type (submodule )
86
+ exists_in_registry = to_replace in SUBSTITUTION_REGISTRY
87
+ elif n .op == "call_function" :
88
+ # Extract function from graph, validate in registry
89
+ to_replace = n .target
90
+ exists_in_registry = n .target in SUBSTITUTION_REGISTRY
91
+
92
+ # If submodule/function is a member of the substitution registry, replace it
93
+ if exists_in_registry :
94
+ try :
95
+ replacement = SUBSTITUTION_REGISTRY [to_replace ]
96
+ op , insertion_fn = (
97
+ replacement .new_operator ,
98
+ replacement .subgraph_insertion_fn ,
99
+ )
100
+ logger .debug (f"Replacing node of type { to_replace } with { op } " )
101
+
102
+ # Insert new node prior to older node
103
+ with gm .graph .inserting_before (n ):
104
+ new_node = insertion_fn (
105
+ gm , n , submodule if n .op == "call_module" else None
91
106
)
92
107
93
- # Insert new node prior to older node
94
- with gm .graph .inserting_before (n ):
95
- new_node = insertion_fn (gm , submodule , n )
96
-
97
- # If submodule is not a native torch.nn module, it must be manually excluded
98
- # from Dynamo tracing
99
- if not type (submodule ).__module__ .startswith ("torch.nn" ):
100
- torch ._dynamo .allowed_functions ._allowed_function_ids .add (
101
- id (type (submodule ))
102
- )
103
-
104
- # Replace all original node uses and clean up graph
105
- n .replace_all_uses_with (new_node )
106
- gm .graph .eliminate_dead_code ()
107
- gm .graph .lint ()
108
- gm .recompile ()
109
-
110
- # A module replacement can fail in the event that the specific instance of the submodule cannot
111
- # be replaced
112
- except Exception :
113
- logger .debug (
114
- f"Encountered error while replacing { type (submodule )} " ,
115
- exc_info = True ,
108
+ # If submodule is not a native torch.nn module, it must be manually excluded
109
+ # from Dynamo tracing
110
+ if n .op == "call_module" and not type (submodule ).__module__ .startswith (
111
+ "torch.nn"
112
+ ):
113
+ torch ._dynamo .allowed_functions ._allowed_function_ids .add (
114
+ id (to_replace )
116
115
)
117
- continue
116
+
117
+ # Replace all original node uses and clean up graph
118
+ n .replace_all_uses_with (new_node )
119
+ gm .graph .eliminate_dead_code ()
120
+ gm .graph .lint ()
121
+ gm .recompile ()
122
+
123
+ # A replacement can fail in the event that the specific instance of the submodule/function
124
+ # cannot be replaced
125
+ except Exception :
126
+ logger .debug (
127
+ f"Encountered error while replacing { to_replace } " ,
128
+ exc_info = True ,
129
+ )
130
+ continue
118
131
119
132
# Perform cleanup and recompilation before returning module
120
133
gm .graph .eliminate_dead_code ()
0 commit comments