15
15
16
16
class FuseConsecutiveTranspose (ExportPass ):
17
17
"""
18
- This pass fuses consecutive transpose / permute into one to reduce runtime
19
- overhead
18
+ This pass fuses consecutive transpose / permute into one or none to reduce runtime
19
+ overhead.
20
+ To simplify the fuse logic, we ensure each permute node's output has at most 1 permute node
21
+ by cloning transpose.
22
+ Example:
23
+ Before clone transpose:
24
+ relu -> permute1 ─> permute2
25
+ |──────> permute3
26
+
27
+ After clone transpose:
28
+ relu ─> permute1 ──────> permute2
29
+ |───> permute4(new) ─> permute3
20
30
"""
21
31
22
32
def __init__ (self ):
@@ -27,54 +37,81 @@ def __init__(self):
27
37
self .visited = set ()
28
38
self .nodes = []
29
39
40
+ def _clone_transpose (
41
+ self , graph_module : torch .fx .GraphModule
42
+ ) -> torch .fx .GraphModule :
43
+ graph = graph_module .graph
44
+ for n in graph_module .graph .nodes :
45
+ if n .target in self .op_map :
46
+ users = [user for user in list (n .users ) if user .target in self .op_map ]
47
+ if len (users ) > 1 :
48
+ for i in range (1 , len (users )):
49
+ with graph .inserting_after (n ):
50
+ clone_permute_node = graph .create_node (
51
+ "call_function" ,
52
+ exir_ops .edge .aten .permute_copy .default ,
53
+ (n .args [0 ], n .args [1 ]),
54
+ )
55
+ clone_permute_node .meta = n .meta
56
+ users [i ].replace_input_with (n , clone_permute_node )
57
+
58
+ def _is_dispensable (self , axis_order ):
59
+ for index , value in enumerate (axis_order ):
60
+ if index != value :
61
+ return False
62
+ return True
63
+
30
64
def _traverse (self , node ):
31
65
if node in self .visited or node .target not in self .op_map :
32
66
return
33
67
34
68
self .nodes .append (node )
35
69
self .visited .add (node )
36
70
next_users = [n for n in list (node .users ) if n .target in self .op_map ]
71
+
72
+ assert (
73
+ len (next_users ) <= 1
74
+ ), "Each permute node should have at most 1 permute output node after _clone_transpose"
37
75
if not next_users :
38
76
return
39
-
40
- if len (next_users ) == 1 :
41
- self ._traverse (list (node .users )[0 ])
42
77
else :
43
- raise NotImplementedError (
44
- f"Check the node { node } , wich encounter mutilple permute output case"
45
- )
78
+ self ._traverse (list (node .users )[0 ])
46
79
47
80
def _fuse (self , graph_module : torch .fx .GraphModule ) -> torch .fx .GraphModule :
48
81
graph = graph_module .graph
49
82
for n in graph_module .graph .nodes :
50
83
self ._traverse (n )
51
84
if len (self .nodes ) > 1 :
52
- permute_order = []
53
85
input_node , output_node = self .nodes [0 ].args [0 ], self .nodes [- 1 ]
54
86
input_shape = input_node .meta ["val" ].shape
55
87
axis_order = torch .arange (len (input_shape )).tolist ()
56
88
for node in self .nodes :
57
- permute_order .append (node .args [1 ])
58
89
axis_order = [axis_order [i ] for i in node .args [1 ]]
59
- with graph .inserting_after (input_node ):
60
- permute_op = exir_ops .edge .aten .permute_copy .default
61
- permute_node = graph .create_node (
62
- "call_function" , permute_op , (input_node , axis_order )
63
- )
64
- users = output_node .users .copy ()
65
- for user in users :
66
- user .replace_input_with (output_node , permute_node )
67
-
68
- # copy metadata
69
- permute_node .meta = output_node .meta
70
- # Without "qnn_permute", we might obtain wrong input shape
71
- if [pn .meta .get (QCOM_INSERTED_PERMUTE ) for pn in self .nodes ]:
72
- permute_node .meta [QCOM_INSERTED_PERMUTE ] = True
90
+ # If axis order is just [0,1,2,3], we ignore permute node
91
+ if self ._is_dispensable (axis_order ):
92
+ for user in output_node .users .copy ():
93
+ user .replace_input_with (output_node , n .args [0 ])
94
+ else :
95
+ with graph .inserting_after (input_node ):
96
+ permute_op = exir_ops .edge .aten .permute_copy .default
97
+ permute_node = graph .create_node (
98
+ "call_function" , permute_op , (input_node , axis_order )
99
+ )
100
+ users = output_node .users .copy ()
101
+ for user in users :
102
+ user .replace_input_with (output_node , permute_node )
103
+
104
+ # copy metadata
105
+ permute_node .meta = output_node .meta
106
+ # Without "qnn_permute", we might obtain wrong input shape
107
+ if [pn .meta .get (QCOM_INSERTED_PERMUTE ) for pn in self .nodes ]:
108
+ permute_node .meta [QCOM_INSERTED_PERMUTE ] = True
73
109
74
110
# clear current stack
75
111
self .nodes = []
76
112
77
113
def call (self , graph_module : torch .fx .GraphModule ):
114
+ self ._clone_transpose (graph_module )
78
115
self ._fuse (graph_module )
79
116
graph_module .recompile ()
80
117
dead_code_elimination_pass (graph_module )
0 commit comments