Skip to content

Commit 90c24a8

Browse files
nathanaelseefacebook-github-bot
authored andcommitted
update SqueezeInt4LinearInputs to process relu/gelu inputs too
Summary: Update/rename SqueezeInt4LinearInputs pass so it wraps gelu/relu with squeeze/unsqueeze view ops too Differential Revision: D69673068
1 parent da17f66 commit 90c24a8

File tree

5 files changed

+30
-10
lines changed

5 files changed

+30
-10
lines changed

Diff for: backends/transforms/fuse_view_copy.py

+15
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,23 @@ def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
3939
graph.eliminate_dead_code()
4040
return graph
4141

42+
def remove_noop_view_copy(graph: torch.fx.Graph) -> torch.fx.Graph:
43+
"""
44+
Remove view_copy nodes that are no-ops.
45+
"""
46+
ops = exir_ops.edge
47+
view_op = ops.aten.view_copy.default
48+
for node in graph.nodes:
49+
if node.op == "call_function" and node.target == view_op:
50+
input_shape = list(node.args[0].meta["val"].shape)
51+
target_shape = node.args[1]
52+
if input_shape == target_shape:
53+
node.replace_all_uses_with(node.args[0])
54+
graph.eliminate_dead_code()
55+
return graph
4256

4357
class FuseViewCopyTransform(ExportPass):
4458
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
4559
graph_module.graph = merge_view_copy_chains(graph_module.graph)
60+
graph_module.graph = remove_noop_view_copy(graph_module.graph)
4661
return PassResult(graph_module, True)

Diff for: backends/vulkan/_passes/TARGETS

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ runtime.python_library(
3131
)
3232

3333
runtime.python_library(
34-
name = "squeeze_int4_linear_inputs",
34+
name = "squeeze_unsqueeze_inputs",
3535
srcs = [
36-
"squeeze_int4_linear_inputs.py",
36+
"squeeze_unsqueeze_inputs.py",
3737
],
3838
visibility = [
3939
"//executorch/backends/...",
@@ -114,7 +114,7 @@ runtime.python_library(
114114
":remove_asserts",
115115
":remove_local_scalar_dense",
116116
":remove_redundant_ops",
117-
":squeeze_int4_linear_inputs",
117+
":squeeze_unsqueeze_inputs",
118118
":tag_memory_meta_pass",
119119
]
120120
)

Diff for: backends/vulkan/_passes/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from executorch.backends.vulkan._passes.remove_redundant_ops import (
2121
RemoveRedundantOpsTransform,
2222
)
23-
from executorch.backends.vulkan._passes.squeeze_int4_linear_inputs import (
24-
SqueezeInt4LinearInputs,
23+
from executorch.backends.vulkan._passes.squeeze_unsqueeze_inputs import (
24+
SqueezeUnsqueezeInputs,
2525
)
2626
from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
2727

@@ -32,6 +32,6 @@
3232
"RemoveAssertsTransform",
3333
"RemoveLocalScalarDenseOpsTransform",
3434
"RemoveRedundantOpsTransform",
35-
"SqueezeInt4LinearInputs",
35+
"SqueezeUnsqueezeInputs",
3636
"TagMemoryMetaPass",
3737
]

Diff for: backends/vulkan/_passes/squeeze_int4_linear_inputs.py renamed to backends/vulkan/_passes/squeeze_unsqueeze_inputs.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@
1414

1515
from torch.fx.node import Argument
1616

17+
class SqueezeUnsqueezeInputs(ExportPass):
18+
_squeezable_ops = [
19+
exir_ops.edge.et_vk.linear_weight_int4.default,
20+
exir_ops.edge.aten.relu.default,
21+
exir_ops.edge.aten.gelu.default,
22+
]
1723

18-
class SqueezeInt4LinearInputs(ExportPass):
1924
def call_operator(
2025
self,
2126
op, # pyre-ignore
@@ -26,7 +31,7 @@ def call_operator(
2631
def _squeezable(shape: List[int]) -> bool:
2732
return len(shape) > 2 and 1 in shape
2833

29-
if op != exir_ops.edge.et_vk.linear_weight_int4.default:
34+
if op not in self._squeezable_ops:
3035
return super().call_operator(op, args, kwargs, meta)
3136

3237
# pyre-ignore[16]: `None` has no attribute `node`

Diff for: backends/vulkan/vulkan_preprocess.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
insert_prepack_nodes,
2727
RemoveLocalScalarDenseOpsTransform,
2828
RemoveRedundantOpsTransform,
29-
SqueezeInt4LinearInputs,
29+
SqueezeUnsqueezeInputs,
3030
TagMemoryMetaPass,
3131
)
3232

@@ -153,7 +153,7 @@ def preprocess( # noqa: C901
153153
RemoveRedundantOpsTransform(),
154154
AddmmToLinearTransform(),
155155
FuseDequantLinearPass(),
156-
SqueezeInt4LinearInputs(),
156+
SqueezeUnsqueezeInputs(),
157157
FuseViewCopyTransform(),
158158
ViewCopyToSqueezeUnsqueezePass(),
159159
FuseBatchNormWithConvPass(program),

0 commit comments

Comments
 (0)