Commit 917f9e2

Author: Zonglin Peng
Revert "Qualcomm AI Engine Direct - ConvFormer Enablement (pytorch#6654)"
This reverts commit e8b5987.
1 parent e8b5987 · commit 917f9e2

File tree

11 files changed (+48, -428 lines)

backends/qualcomm/_passes/fuse_consecutive_transpose.py

Lines changed: 24 additions & 61 deletions
@@ -15,18 +15,8 @@
 
 class FuseConsecutiveTranspose(ExportPass):
     """
-    This pass fuses consecutive transpose / permute into one or none to reduce runtime
-    overhead.
-    To simplify the fuse logic, we ensure each permute node's output has at most 1 permute node
-    by cloning transpose.
-    Example:
-        Before clone transpose:
-            relu -> permute1 ─> permute2
-                        |──────> permute3
-
-        After clone transpose:
-            relu ─> permute1 ──────> permute2
-              |───> permute4(new) ─> permute3
+    This pass fuses consecutive transpose / permute into one to reduce runtime
+    overhead
     """
 
     def __init__(self):
@@ -37,81 +27,54 @@ def __init__(self):
         self.visited = set()
         self.nodes = []
 
-    def _clone_transpose(
-        self, graph_module: torch.fx.GraphModule
-    ) -> torch.fx.GraphModule:
-        graph = graph_module.graph
-        for n in graph_module.graph.nodes:
-            if n.target in self.op_map:
-                users = [user for user in list(n.users) if user.target in self.op_map]
-                if len(users) > 1:
-                    for i in range(1, len(users)):
-                        with graph.inserting_after(n):
-                            clone_permute_node = graph.create_node(
-                                "call_function",
-                                exir_ops.edge.aten.permute_copy.default,
-                                (n.args[0], n.args[1]),
-                            )
-                            clone_permute_node.meta = n.meta
-                            users[i].replace_input_with(n, clone_permute_node)
-
-    def _is_dispensable(self, axis_order):
-        for index, value in enumerate(axis_order):
-            if index != value:
-                return False
-        return True
-
     def _traverse(self, node):
         if node in self.visited or node.target not in self.op_map:
             return
 
         self.nodes.append(node)
         self.visited.add(node)
         next_users = [n for n in list(node.users) if n.target in self.op_map]
-
-        assert (
-            len(next_users) <= 1
-        ), "Each permute node should have at most 1 permute output node after _clone_transpose"
         if not next_users:
             return
-        else:
+
+        if len(next_users) == 1:
             self._traverse(list(node.users)[0])
+        else:
+            raise NotImplementedError(
+                f"Check the node {node}, wich encounter mutilple permute output case"
+            )
 
     def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
         graph = graph_module.graph
         for n in graph_module.graph.nodes:
             self._traverse(n)
             if len(self.nodes) > 1:
+                permute_order = []
                 input_node, output_node = self.nodes[0].args[0], self.nodes[-1]
                 input_shape = input_node.meta["val"].shape
                 axis_order = torch.arange(len(input_shape)).tolist()
                 for node in self.nodes:
+                    permute_order.append(node.args[1])
                     axis_order = [axis_order[i] for i in node.args[1]]
-                # If axis order is just [0,1,2,3], we ignore permute node
-                if self._is_dispensable(axis_order):
-                    for user in output_node.users.copy():
-                        user.replace_input_with(output_node, n.args[0])
-                else:
-                    with graph.inserting_after(input_node):
-                        permute_op = exir_ops.edge.aten.permute_copy.default
-                        permute_node = graph.create_node(
-                            "call_function", permute_op, (input_node, axis_order)
-                        )
-                        users = output_node.users.copy()
-                        for user in users:
-                            user.replace_input_with(output_node, permute_node)
-
-                    # copy metadata
-                    permute_node.meta = output_node.meta
-                    # Without "qnn_permute", we might obtain wrong input shape
-                    if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
-                        permute_node.meta[QCOM_INSERTED_PERMUTE] = True
+                with graph.inserting_after(input_node):
+                    permute_op = exir_ops.edge.aten.permute_copy.default
+                    permute_node = graph.create_node(
+                        "call_function", permute_op, (input_node, axis_order)
+                    )
+                    users = output_node.users.copy()
+                    for user in users:
+                        user.replace_input_with(output_node, permute_node)
+
+                # copy metadata
+                permute_node.meta = output_node.meta
+                # Without "qnn_permute", we might obtain wrong input shape
+                if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
+                    permute_node.meta[QCOM_INSERTED_PERMUTE] = True
 
             # clear current stack
             self.nodes = []
 
     def call(self, graph_module: torch.fx.GraphModule):
-        self._clone_transpose(graph_module)
         self._fuse(graph_module)
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
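
For context on the rule this pass relies on: applying permute(p1) followed by permute(p2) equals a single permute whose order is [p1[i] for i in p2], which is exactly the per-node update axis_order = [axis_order[i] for i in node.args[1]] above. A minimal sketch in plain PyTorch (the example orders are illustrative):

    import torch

    x = torch.randn(2, 3, 4, 5)
    p1, p2 = [0, 2, 3, 1], [0, 3, 1, 2]  # e.g. NCHW -> NHWC, then back

    # Same composition rule as the loop over self.nodes in _fuse
    composed = [p1[i] for i in p2]
    assert composed == [0, 1, 2, 3]  # these two orders cancel to the identity
    assert torch.equal(x.permute(p1).permute(p2), x.permute(composed))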

backends/qualcomm/_passes/layout_transform.py

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ class LayoutTransform(ExportPass):
     """
 
     layout_sensitive_ops = {
-        exir_ops.edge.aten.adaptive_avg_pool2d.default,
         exir_ops.edge.aten.avg_pool2d.default,
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.aten.max_pool2d_with_indices.default,

backends/qualcomm/builders/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -7,7 +7,6 @@
 from . import (
     node_visitor,
     op_abs,
-    op_adaptive_avg_pool2d,
     op_add,
     op_arange,
     op_avg_pool2d,
@@ -79,7 +78,6 @@
 __all__ = [
     node_visitor,
     op_abs,
-    op_adaptive_avg_pool2d,
     op_add,
     op_arange,
     op_avg_pool2d,

backends/qualcomm/builders/op_adaptive_avg_pool2d.py

Lines changed: 0 additions & 125 deletions
This file was deleted.
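
The deleted builder lowered aten.adaptive_avg_pool2d to QNN. As a reminder of the op's semantics (a standalone sketch, not the deleted code): with output size (1, 1), the case exercised by the removed test model below, adaptive average pooling is simply a mean over the spatial dimensions:

    import torch

    x = torch.randn(1, 8, 7, 7)  # NCHW
    pooled = torch.nn.AdaptiveAvgPool2d((1, 1))(x)

    # For output size (1, 1) the op reduces to a spatial mean
    assert torch.allclose(pooled, x.mean(dim=(2, 3), keepdim=True))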

backends/qualcomm/builders/op_layer_norm.py

Lines changed: 11 additions & 13 deletions
@@ -63,19 +63,15 @@ def define_node(
             nodes_to_wrappers,
         )
 
-        layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]
-
         bias_node = node.args[3]
-        if bias_node is not None:
-            bias_tensor = get_parameter(bias_node, self.edge_program)
-            bias_tensor_wrapper = self.define_tensor(
-                bias_node,
-                node,
-                bias_tensor,
-                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
-                nodes_to_wrappers,
-            )
-            layer_norm_input_tensors.append(bias_tensor_wrapper)
+        bias_tensor = get_parameter(bias_node, self.edge_program)
+        bias_tensor_wrapper = self.define_tensor(
+            bias_node,
+            node,
+            bias_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+            nodes_to_wrappers,
+        )
 
         epsilon = node.args[4]
 
@@ -93,7 +89,9 @@ def define_node(
             QNN_OP_PACKAGE_NAME_QTI_AISW,
             OpLayerNorm.op_name,
         )
-        layer_norm_op.AddInputTensors(layer_norm_input_tensors)
+        layer_norm_op.AddInputTensors(
+            [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
+        )
         layer_norm_op.AddOutputTensors([output_tensor_wrapper])
         layer_norm_op.AddScalarParam(
             OpLayerNorm.param_epsilon,
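
After the revert, the builder again reads node.args[3] unconditionally and always hands QNN three input tensors, so a layer_norm without a bias is no longer supported by this path. For reference, the semantics being mapped (a plain-PyTorch sketch with illustrative shapes):

    import torch

    x = torch.randn(4, 768)
    weight, bias, eps = torch.randn(768), torch.randn(768), 1e-6

    # Normalize over the trailing dim, then apply scale and shift
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    manual = (x - mean) / torch.sqrt(var + eps) * weight + bias

    reference = torch.nn.functional.layer_norm(x, [768], weight, bias, eps)
    assert torch.allclose(manual, reference, atol=1e-5)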

backends/qualcomm/builders/op_rms_norm.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def define_node(
             nodes_to_wrappers,
         )
 
-        # Fake node, nn module seems to be inconsistant with document
+        # Fake node, nn moudle seems to be inconsistant with document
         bias_tensor = torch.zeros(weight_tensor.shape)
         bias_node = torch.fx.Node(
             node.graph,
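
The zero-filled bias_tensor built here stands in for an input the backend op expects but which aten.rms_norm does not carry. A quick sketch of why a zero bias is harmless, assuming the standard RMSNorm formulation:

    import torch

    x, weight, eps = torch.randn(4, 64), torch.randn(64), 1e-6

    # RMSNorm: rescale by the root-mean-square, then apply the weight
    rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    out = x / rms * weight

    # Adding a bias of zeros leaves the result unchanged
    assert torch.equal(out + torch.zeros(64), out)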

backends/qualcomm/quantizer/annotators.py

Lines changed: 0 additions & 5 deletions
@@ -512,11 +512,6 @@ def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
 
-@register_annotator([torch.ops.aten.square.default])
-def annotate_square(node: Node, quantization_config: QuantizationConfig) -> None:
-    annotate_single_in_single_out(node, quantization_config)
-
-
 @register_annotator([torch.ops.aten.gelu.default])
 def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
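
The removed annotator covered torch.ops.aten.square.default, an elementwise single-input, single-output op (hence annotate_single_in_single_out); after the revert, square nodes simply go unannotated. For reference:

    import torch

    x = torch.randn(3, 3)
    # aten.square is elementwise, equivalent to x * x
    assert torch.allclose(torch.ops.aten.square.default(x), x * x)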

backends/qualcomm/tests/models.py

Lines changed: 2 additions & 20 deletions
@@ -16,15 +16,6 @@ def forward(self, x):
         return torch.abs(x)
 
 
-class AdaptiveAvgPool2D(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
-        return adaptive_avg_pool(x)
-
-
 class Add(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -694,24 +685,15 @@ def forward(self, x)
 
 
 class LayerNorm(torch.nn.Module):
-    def __init__(self, bias=True):
+    def __init__(self):
         super().__init__()
-        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6, bias=bias)
+        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
         self.linear = torch.nn.Linear(768, 196)
 
     def forward(self, x):
         return self.linear(self.layer_norm(x))
 
 
-class LayerNormAdd(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.layer_norm = torch.nn.LayerNorm([512], eps=1e-6, bias=False)
-
-    def forward(self, x, y):
-        return self.layer_norm(x) + y
-
-
 class LeakyReLUDefault(torch.nn.Module):
     def __init__(self):
         super().__init__()
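
The revert drops the bias flag from the LayerNorm test model and deletes LayerNormAdd, its only bias=False user, matching the builder change above. For reference, a sketch of what the flag controlled (torch.nn.LayerNorm accepts bias= in PyTorch 2.1 and later):

    import torch

    with_bias = torch.nn.LayerNorm([768], eps=1e-6, bias=True)
    without_bias = torch.nn.LayerNorm([768], eps=1e-6, bias=False)

    # bias=False omits the learnable shift entirely
    assert with_bias.bias is not None and without_bias.bias is None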
