1
1
from dataclasses import dataclass
2
- from typing import Any , Callable , Dict , Type
2
+ from typing import Any , Callable , Dict , Optional , Type , Union
3
3
import torch
4
4
import logging
5
5
8
8
9
9
10
10
@dataclass (frozen = True )
11
- class ModuleReplacement :
11
+ class Substitution :
12
12
"""Class to store key functionality for module replacement"""
13
13
14
14
# torch.ops.___ name for replacement function for module
15
15
new_operator : torch ._ops .OpOverload
16
16
17
- # Function taking a containing graph, a submodule, and a 'call_module' node and returning
18
- # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
17
+ # Function taking a containing graph, a node, and optionally a submodule (if replacing a module)
18
+ # and returning a replacement node, with type 'call_function', or raising an Error if
19
+ # incompatibility is detected
19
20
# Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
20
21
subgraph_insertion_fn : Callable [
21
- [torch .fx .GraphModule , torch .nn . Module , torch .fx . Node ], torch .fx .Node
22
+ [torch .fx .GraphModule , torch .fx . Node , Optional [ torch .nn . Module ] ], torch .fx .Node
22
23
]
23
24
24
25
25
- # Dictionary mapping module to ModuleReplacement instance
26
- MODULE_SUBSTITUTION_REGISTRY : Dict [Type [torch .nn .Module ], ModuleReplacement ] = dict ()
26
+ # Dictionary mapping module to Substitution instance
27
+ SUBSTITUTION_REGISTRY : Dict [
28
+ Union [Type [torch .nn .Module ], Callable ], Substitution
29
+ ] = dict ()
27
30
28
31
29
- def module_substitution (
30
- module_to_replace : Type [torch .nn .Module ],
32
+ def register_substitution (
33
+ module_or_function_to_replace : Union [ Type [torch .nn .Module ], Callable ],
31
34
new_operator : torch ._ops .OpOverload ,
32
35
enabled : bool = True ,
33
36
) -> Callable [[Any ], Any ]:
34
37
"""Decorator to register subgraph insertion functions
35
38
36
39
Args:
37
- module_to_replace : nn.Module to replace
40
+ module_or_function_to_replace : nn.Module or node target Callable to replace
38
41
new_operator: Custom torch operator to replace with
39
42
enabled: Whether the substitution is enabled or disabled
40
43
Returns:
41
44
torch.fx.GraphModule
42
45
"""
43
46
44
- def register_substitution (subgraph_insertion_fn ):
47
+ def enable_substitution (subgraph_insertion_fn ):
45
48
"""Function for use if substitution is enabled"""
46
- module_replacement = ModuleReplacement (
49
+ replacement = Substitution (
47
50
new_operator = new_operator , subgraph_insertion_fn = subgraph_insertion_fn
48
51
)
49
- MODULE_SUBSTITUTION_REGISTRY [ module_to_replace ] = module_replacement
52
+ SUBSTITUTION_REGISTRY [ module_or_function_to_replace ] = replacement
50
53
return subgraph_insertion_fn
51
54
52
55
def disable_substitution (subgraph_insertion_fn ):
53
56
"""Function for use if substitution is disabled"""
54
57
return subgraph_insertion_fn
55
58
56
- return register_substitution if enabled else disable_substitution
59
+ return enable_substitution if enabled else disable_substitution
57
60
58
61
59
- def pre_aot_module_replacement (gm : torch .fx .GraphModule ):
60
- """Perform module-level graph replacement prior to AOT tracing
62
+ def pre_aot_substitutions (gm : torch .fx .GraphModule ):
63
+ """Perform graph substitutions prior to AOT tracing
61
64
62
65
Args:
63
- gm: FX GraphModule to perform module replacement on
66
+ gm: FX GraphModule to perform substitution on
64
67
Returns:
65
68
torch.fx.GraphModule
66
69
@@ -71,48 +74,58 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
71
74
72
75
# Iterate over graph nodes, extracting module calls, to check for interceptions
73
76
for n in gm .graph .nodes :
77
+ exists_in_registry = False
78
+ to_replace = None
79
+
74
80
if n .op == "call_module" :
75
- # Extract submodule from graph
81
+ # Extract submodule from graph, validate in registry
76
82
submodule = gm .get_submodule (n .target )
77
-
78
- # If submodule is a member of the substitution registry, replace it
79
- if type (submodule ) in MODULE_SUBSTITUTION_REGISTRY :
80
-
81
- try :
82
- replacement = MODULE_SUBSTITUTION_REGISTRY [type (submodule )]
83
- op , insertion_fn = (
84
- replacement .new_operator ,
85
- replacement .subgraph_insertion_fn ,
86
- )
87
- logger .debug (
88
- f"Replacing module of type { type (submodule )} with { op } "
83
+ to_replace = type (submodule )
84
+ exists_in_registry = to_replace in SUBSTITUTION_REGISTRY
85
+ elif n .op == "call_function" :
86
+ # Extract function from graph, validate in registry
87
+ to_replace = n .target
88
+ exists_in_registry = n .target in SUBSTITUTION_REGISTRY
89
+
90
+ # If submodule/function is a member of the substitution registry, replace it
91
+ if exists_in_registry :
92
+ try :
93
+ replacement = SUBSTITUTION_REGISTRY [to_replace ]
94
+ op , insertion_fn = (
95
+ replacement .new_operator ,
96
+ replacement .subgraph_insertion_fn ,
97
+ )
98
+ logger .debug (f"Replacing node of type { to_replace } with { op } " )
99
+
100
+ # Insert new node prior to older node
101
+ with gm .graph .inserting_before (n ):
102
+ new_node = insertion_fn (
103
+ gm , n , submodule if n .op == "call_module" else None
89
104
)
90
105
91
- # Insert new node prior to older node
92
- with gm .graph .inserting_before (n ):
93
- new_node = insertion_fn (gm , submodule , n )
94
-
95
- # If submodule is not a native torch.nn module, it must be manually excluded
96
- # from Dynamo tracing
97
- if not type (submodule ).__module__ .startswith ("torch.nn" ):
98
- torch ._dynamo .allowed_functions ._allowed_function_ids .add (
99
- id (type (submodule ))
100
- )
101
-
102
- # Replace all original node uses and clean up graph
103
- n .replace_all_uses_with (new_node )
104
- gm .graph .eliminate_dead_code ()
105
- gm .graph .lint ()
106
- gm .recompile ()
107
-
108
- # A module replacement can fail in the event that the specific instance of the submodule cannot
109
- # be replaced
110
- except Exception :
111
- logger .debug (
112
- f"Encountered error while replacing { type (submodule )} " ,
113
- exc_info = True ,
106
+ # If submodule is not a native torch.nn module, it must be manually excluded
107
+ # from Dynamo tracing
108
+ if n .op == "call_module" and not type (submodule ).__module__ .startswith (
109
+ "torch.nn"
110
+ ):
111
+ torch ._dynamo .allowed_functions ._allowed_function_ids .add (
112
+ id (to_replace )
114
113
)
115
- continue
114
+
115
+ # Replace all original node uses and clean up graph
116
+ n .replace_all_uses_with (new_node )
117
+ gm .graph .eliminate_dead_code ()
118
+ gm .graph .lint ()
119
+ gm .recompile ()
120
+
121
+ # A replacement can fail in the event that the specific instance of the submodule/function
122
+ # cannot be replaced
123
+ except Exception :
124
+ logger .debug (
125
+ f"Encountered error while replacing { to_replace } " ,
126
+ exc_info = True ,
127
+ )
128
+ continue
116
129
117
130
# Perform cleanup and recompilation before returning module
118
131
gm .graph .eliminate_dead_code ()
0 commit comments