
Commit d17f45b

Qualcomm AI Engine Direct - ConvFormer Enablement (#6654)
1 parent 71c0ad8 commit d17f45b

File tree: 11 files changed, +428 −48 lines


backends/qualcomm/_passes/fuse_consecutive_transpose.py

+61 -24

@@ -15,8 +15,18 @@
 
 class FuseConsecutiveTranspose(ExportPass):
     """
-    This pass fuses consecutive transpose / permute into one to reduce runtime
-    overhead
+    This pass fuses consecutive transpose / permute into one or none to reduce runtime
+    overhead.
+    To simplify the fuse logic, we ensure each permute node's output has at most 1 permute node
+    by cloning transpose.
+    Example:
+        Before clone transpose:
+            relu -> permute1 ─> permute2
+                        |──────> permute3
+
+        After clone transpose:
+            relu ─> permute1 ──────> permute2
+                |───> permute4(new) ─> permute3
     """
 
     def __init__(self):
@@ -27,54 +37,81 @@ def __init__(self):
         self.visited = set()
         self.nodes = []
 
+    def _clone_transpose(
+        self, graph_module: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        graph = graph_module.graph
+        for n in graph_module.graph.nodes:
+            if n.target in self.op_map:
+                users = [user for user in list(n.users) if user.target in self.op_map]
+                if len(users) > 1:
+                    for i in range(1, len(users)):
+                        with graph.inserting_after(n):
+                            clone_permute_node = graph.create_node(
+                                "call_function",
+                                exir_ops.edge.aten.permute_copy.default,
+                                (n.args[0], n.args[1]),
+                            )
+                            clone_permute_node.meta = n.meta
+                            users[i].replace_input_with(n, clone_permute_node)
+
+    def _is_dispensable(self, axis_order):
+        for index, value in enumerate(axis_order):
+            if index != value:
+                return False
+        return True
+
     def _traverse(self, node):
         if node in self.visited or node.target not in self.op_map:
             return
 
         self.nodes.append(node)
         self.visited.add(node)
         next_users = [n for n in list(node.users) if n.target in self.op_map]
+
+        assert (
+            len(next_users) <= 1
+        ), "Each permute node should have at most 1 permute output node after _clone_transpose"
         if not next_users:
             return
-
-        if len(next_users) == 1:
-            self._traverse(list(node.users)[0])
         else:
-            raise NotImplementedError(
-                f"Check the node {node}, wich encounter mutilple permute output case"
-            )
+            self._traverse(list(node.users)[0])
 
     def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
         graph = graph_module.graph
         for n in graph_module.graph.nodes:
             self._traverse(n)
             if len(self.nodes) > 1:
-                permute_order = []
                 input_node, output_node = self.nodes[0].args[0], self.nodes[-1]
                 input_shape = input_node.meta["val"].shape
                 axis_order = torch.arange(len(input_shape)).tolist()
                 for node in self.nodes:
-                    permute_order.append(node.args[1])
                     axis_order = [axis_order[i] for i in node.args[1]]
-                with graph.inserting_after(input_node):
-                    permute_op = exir_ops.edge.aten.permute_copy.default
-                    permute_node = graph.create_node(
-                        "call_function", permute_op, (input_node, axis_order)
-                    )
-                    users = output_node.users.copy()
-                    for user in users:
-                        user.replace_input_with(output_node, permute_node)
-
-                    # copy metadata
-                    permute_node.meta = output_node.meta
-                    # Without "qnn_permute", we might obtain wrong input shape
-                    if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
-                        permute_node.meta[QCOM_INSERTED_PERMUTE] = True
+                # If axis order is just [0,1,2,3], we ignore permute node
+                if self._is_dispensable(axis_order):
+                    for user in output_node.users.copy():
+                        user.replace_input_with(output_node, n.args[0])
+                else:
+                    with graph.inserting_after(input_node):
+                        permute_op = exir_ops.edge.aten.permute_copy.default
+                        permute_node = graph.create_node(
+                            "call_function", permute_op, (input_node, axis_order)
+                        )
+                        users = output_node.users.copy()
+                        for user in users:
+                            user.replace_input_with(output_node, permute_node)
+
+                        # copy metadata
+                        permute_node.meta = output_node.meta
+                        # Without "qnn_permute", we might obtain wrong input shape
+                        if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
+                            permute_node.meta[QCOM_INSERTED_PERMUTE] = True
 
                 # clear current stack
                 self.nodes = []
 
     def call(self, graph_module: torch.fx.GraphModule):
+        self._clone_transpose(graph_module)
        self._fuse(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
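
The core of the fuse step is permutation composition: starting from the identity order, each permute's dims re-index the running axis_order, and if the final order is the identity the whole chain is dropped. A minimal sketch of that arithmetic in plain PyTorch (illustration only, not part of this commit):

    import torch

    x = torch.randn(1, 3, 8, 8)
    perm1 = [0, 2, 3, 1]  # e.g. NCHW -> NHWC
    perm2 = [0, 3, 1, 2]  # e.g. NHWC -> NCHW

    # Same composition rule as _fuse: re-index the running axis order by each
    # permute's dims in turn.
    axis_order = list(range(x.dim()))
    for dims in (perm1, perm2):
        axis_order = [axis_order[i] for i in dims]

    print(axis_order)  # [0, 1, 2, 3] -> identity, so the chain is dispensable
    assert torch.equal(x.permute(perm1).permute(perm2), x.permute(axis_order))

When the composed order is not the identity, the pass emits a single permute_copy with that order in place of the whole chain.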

backends/qualcomm/_passes/layout_transform.py

+1

@@ -30,6 +30,7 @@ class LayoutTransform(ExportPass):
     """
 
     layout_sensitive_ops = {
+        exir_ops.edge.aten.adaptive_avg_pool2d.default,
        exir_ops.edge.aten.avg_pool2d.default,
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.aten.max_pool2d_with_indices.default,

backends/qualcomm/builders/__init__.py

+2

@@ -7,6 +7,7 @@
 from . import (
     node_visitor,
     op_abs,
+    op_adaptive_avg_pool2d,
     op_add,
     op_arange,
     op_avg_pool2d,
@@ -78,6 +79,7 @@
 __all__ = [
     node_visitor,
     op_abs,
+    op_adaptive_avg_pool2d,
     op_add,
     op_arange,
     op_avg_pool2d,
backends/qualcomm/builders/op_adaptive_avg_pool2d.py

+125

@@ -0,0 +1,125 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import warnings
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import numpy as np
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class AdaptiveAvgPool2D(NodeVisitor):
+    target = ["aten.adaptive_avg_pool2d.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        input_height = input_tensor.shape[1]
+        input_width = input_tensor.shape[2]
+
+        output_height = node.args[1][0]
+        output_width = node.args[1][1]
+
+        filter_height = input_height // output_height
+        filter_width = input_width // output_width
+        filter = [filter_height, filter_width]
+        filter_shape = [len(filter)]
+
+        stride_height = filter_height
+        stride_width = filter_width
+        stride = [stride_height, stride_width]
+        stride_shape = [len(stride)]
+
+        height = (output_height - 1) * stride_height + filter_height - input_height
+        width = (output_width - 1) * stride_width + filter_width - input_width
+        if height % 2 != 0 or width % 2 != 0:
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Height or Width is not divisble by 2 with no remainder, fall back op",
+                stacklevel=1,
+            )
+            return
+
+        padding_height = height / 2
+        padding_width = width / 2
+        padding = [padding_height, padding_width]
+        padding_shape = [2, 2]
+
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        adaptive_avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpPoolAvg2d.op_name,
+        )
+
+        adaptive_avg_pool2d_op.AddInputTensors([input_tensor_wrapper])
+        adaptive_avg_pool2d_op.AddOutputTensors([output_tensor_wrapper])
+
+        adaptive_avg_pool2d_op.AddTensorParam(
+            OpPoolAvg2d.param_filter_size,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(filter_shape),
+            filter_shape,
+            np.array(
+                filter,
+                dtype=np.uint32,
+            ),
+            True,
+        )
+
+        adaptive_avg_pool2d_op.AddTensorParam(
+            OpPoolAvg2d.param_stride,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(stride_shape),
+            stride_shape,
+            np.array(
+                stride,
+                dtype=np.uint32,
+            ),
+            True,
+        )
+
+        adaptive_avg_pool2d_op.AddTensorParam(
+            OpPoolAvg2d.param_pad_amount,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(padding_shape),
+            padding_shape,
+            np.array(
+                [[padding[0], padding[0]], [padding[1], padding[1]]],
+                dtype=np.uint32,
+            ),
+            True,
+        )
+
+        return adaptive_avg_pool2d_op
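
The builder reads height and width from shape[1] and shape[2] because the layout transform above has already put the tensor in NHWC order. Its parameter choice rests on a standard identity: when the input spatial size divides the output size evenly, adaptive average pooling equals a fixed avg_pool2d whose kernel and stride are input_size // output_size. A quick check of that assumption in plain PyTorch (NCHW here, independent of the QNN builder):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 12, 12)
    out_h, out_w = 3, 3

    kernel = (x.shape[2] // out_h, x.shape[3] // out_w)  # filter_height, filter_width
    stride = kernel                                      # stride == filter size

    adaptive = F.adaptive_avg_pool2d(x, (out_h, out_w))
    fixed = F.avg_pool2d(x, kernel_size=kernel, stride=stride)
    assert torch.allclose(adaptive, fixed)

When the sizes do not line up (the odd padding case in the builder), it warns and returns nothing, so the node falls back rather than being lowered with mismatched parameters.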

backends/qualcomm/builders/op_layer_norm.py

+13 -11

@@ -63,15 +63,19 @@ def define_node(
             nodes_to_wrappers,
         )
 
+        layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]
+
         bias_node = node.args[3]
-        bias_tensor = get_parameter(bias_node, self.edge_program)
-        bias_tensor_wrapper = self.define_tensor(
-            bias_node,
-            node,
-            bias_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
-            nodes_to_wrappers,
-        )
+        if bias_node is not None:
+            bias_tensor = get_parameter(bias_node, self.edge_program)
+            bias_tensor_wrapper = self.define_tensor(
+                bias_node,
+                node,
+                bias_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+                nodes_to_wrappers,
+            )
+            layer_norm_input_tensors.append(bias_tensor_wrapper)
 
         epsilon = node.args[4]
 
@@ -89,9 +93,7 @@ def define_node(
             QNN_OP_PACKAGE_NAME_QTI_AISW,
             OpLayerNorm.op_name,
         )
-        layer_norm_op.AddInputTensors(
-            [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
-        )
+        layer_norm_op.AddInputTensors(layer_norm_input_tensors)
         layer_norm_op.AddOutputTensors([output_tensor_wrapper])
         layer_norm_op.AddScalarParam(
             OpLayerNorm.param_epsilon,
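
The bias guard matters because torch.nn.LayerNorm(..., bias=False) registers no bias parameter, so the captured layer_norm node carries None in its bias slot, which is exactly what the "if bias_node is not None" branch handles. A quick eager-mode illustration (assumes a torch version where LayerNorm exposes the bias flag):

    import torch

    ln = torch.nn.LayerNorm([512], eps=1e-6, bias=False)
    print([name for name, _ in ln.named_parameters()])  # ['weight'] -- no bias
    print(ln(torch.randn(2, 4, 512)).shape)             # torch.Size([2, 4, 512])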

backends/qualcomm/builders/op_rms_norm.py

+1 -1

@@ -66,7 +66,7 @@ def define_node(
             nodes_to_wrappers,
         )
 
-        # Fake node, nn moudle seems to be inconsistant with document
+        # Fake node, nn module seems to be inconsistant with document
         bias_tensor = torch.zeros(weight_tensor.shape)
         bias_node = torch.fx.Node(
             node.graph,

backends/qualcomm/quantizer/annotators.py

+5

@@ -512,6 +512,11 @@ def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.square.default])
+def annotate_square(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.gelu.default])
 def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
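
torch.square lowers to torch.ops.aten.square.default, so the new annotator treats it the same way as sqrt and gelu: a single-input, single-output op sharing one quantization config. A hypothetical module that would exercise it (illustration only, not part of this commit):

    import torch

    class Square(torch.nn.Module):
        def forward(self, x):
            # Captured as torch.ops.aten.square.default and annotated by
            # annotate_square during quantization.
            return torch.square(x)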

backends/qualcomm/tests/models.py

+20 -2

@@ -16,6 +16,15 @@ def forward(self, x):
         return torch.abs(x)
 
 
+class AdaptiveAvgPool2D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
+        return adaptive_avg_pool(x)
+
+
 class Add(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -685,15 +694,24 @@ def forward(self, x):
 
 
 class LayerNorm(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, bias=True):
         super().__init__()
-        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
+        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6, bias=bias)
         self.linear = torch.nn.Linear(768, 196)
 
     def forward(self, x):
         return self.linear(self.layer_norm(x))
 
 
+class LayerNormAdd(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer_norm = torch.nn.LayerNorm([512], eps=1e-6, bias=False)
+
+    def forward(self, x, y):
+        return self.layer_norm(x) + y
+
+
 class LeakyReLUDefault(torch.nn.Module):
     def __init__(self):
         super().__init__()
