Skip to content

Arm backend: Add FuseEqualPlaceholdersPass #9893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
from .insert_rescales_pass import InsertRescalePass # noqa
from .insert_table_ops import InsertTableOpsPass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
FoldAndAnnotateQParamsPass,
FuseBatchnorm2DPass,
FuseConstantArgsPass,
FuseEqualPlaceholdersPass,
FuseQuantizedActivationPass,
InsertRescalePass,
InsertTableOpsPass,
Expand Down Expand Up @@ -108,6 +109,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(FuseConstantArgsPass(exported_program))

self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(AnnotateChannelsLastDimOrder())
self.add_pass(InsertRescalePass())

Expand Down Expand Up @@ -155,6 +157,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(AnnotateChannelsLastDimOrder())
self.add_pass(InsertRescalePass())

Expand Down
83 changes: 83 additions & 0 deletions backends/arm/_passes/fuse_equal_placeholders_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
get_constant_placeholder_kind,
get_param_tensor,
is_param_node,
)
from executorch.backends.transforms.utils import (
create_constant_placeholder,
delete_constant_placeholder,
)
from executorch.exir import ExportedProgram
from executorch.exir.pass_base import ExportPass, PassResult


class FuseEqualPlaceholdersPass(ExportPass):
"""
This pass optimizes memory usage by finding constant placeholders
pointing to identical tensors and fusing them to one single placeholder
with multiple users.
"""

def __init__(self, exported_program: ExportedProgram):
self.exported_program = exported_program
super().__init__()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
modified = False
const_placeholder_nodes = []
for node in graph_module.graph.nodes:
if is_param_node(self.exported_program, node):
const_placeholder_nodes.append(node)

while const_placeholder_nodes:

# Find equal tensors
node1 = const_placeholder_nodes.pop()
eq_nodes = [node1]
tensor1 = get_param_tensor(self.exported_program, node1)
if tensor1 is None:
continue

for node2 in const_placeholder_nodes:
tensor2 = get_param_tensor(self.exported_program, node2)
if tensor2 is None:
continue

if torch.equal(tensor1, tensor2):
eq_nodes.append(node2)

if len(eq_nodes) > 1:
common_name = node1.name + "_common"
common_kind = get_constant_placeholder_kind(
self.exported_program, node1
)
common_persisten_buffer = True

with graph_module.graph.inserting_before(node1):
common_node = create_constant_placeholder(
self.exported_program,
graph_module.graph,
common_name,
common_kind,
tensor1,
common_persisten_buffer,
)

for eq_node in eq_nodes:
eq_node.replace_all_uses_with(common_node)
delete_constant_placeholder(self.exported_program, eq_node)
if eq_node != node1:
const_placeholder_nodes.remove(eq_node)

modified = True

if modified:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module=graph_module, modified=modified)
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
from typing import Tuple

import torch
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
FuseEqualPlaceholdersPass,
)
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline

input_t = Tuple[torch.Tensor] # Input x


class FuseWeightsConstants(torch.nn.Module):
ops_before_pass = {}
ops_after_pass = {}
ops_not_after_pass = []

def __init__(
self,
):
super().__init__()
self.weights1 = torch.rand(1, 2, 1)
self.weights2 = deepcopy(self.weights1)
self.bias1 = torch.rand(1)
self.bias2 = deepcopy(self.bias1)
self.bias3 = deepcopy(self.bias1)

def forward(self, x):
return (
torch.conv1d(x, self.weights1, self.bias1)
+ torch.conv1d(x, self.weights2, self.bias2)
+ self.bias3
)


class FuseWeightsStateDict(torch.nn.Module):
ops_before_pass = {}
ops_after_pass = {}
ops_not_after_pass = []

def __init__(
self,
):
super().__init__()
self.fc1 = torch.nn.Linear(in_features=8, out_features=2, bias=True)
self.fc2 = deepcopy(self.fc1)

def forward(self, x):
return self.fc1(x) + self.fc2(x)


def test_fuse_equal_placeholders_constants_tosa_MI():
module = FuseWeightsConstants()
data = (torch.rand(1, 2, 8),)
pipeline = PassPipeline[input_t](
module,
data,
tosa_version="TOSA-0.80+MI",
ops_before_pass=module.ops_before_pass,
ops_after_pass=module.ops_after_pass,
passes_with_exported_program=[FuseEqualPlaceholdersPass],
)
pipeline.run()

# Check that weights and bias has been merged.
exp_program = pipeline.tester.get_artifact().exported_program()
constant_keys = list(exp_program.constants.keys())
assert len(constant_keys) == 2, "FuseEqualPlaceholders constants failed"
assert "_common" in constant_keys[0], "FuseEqualPlaceholders constants failed"
assert "_common" in constant_keys[1], "FuseEqualPlaceholders constants failed"


def test_fuse_equal_placeholders_state_dict_tosa_MI():
module = FuseWeightsStateDict()
data = (torch.rand(1, 2, 8),)
pipeline = PassPipeline[input_t](
module,
data,
tosa_version="TOSA-0.80+MI",
ops_before_pass=module.ops_before_pass,
ops_after_pass=module.ops_after_pass,
passes_with_exported_program=[FuseEqualPlaceholdersPass],
)
pipeline.run()

# Check that weights and bias has been merged.
exp_program = pipeline.tester.get_artifact().exported_program()
state_dict_keys = list(exp_program.state_dict.keys())
assert len(state_dict_keys) == 2, "FuseEqualPlaceholders state_dict failed"
assert "_common" in state_dict_keys[0], "FuseEqualPlaceholders state_dict failed"
assert "_common" in state_dict_keys[1], "FuseEqualPlaceholders state_dict failed"
2 changes: 1 addition & 1 deletion examples/arm/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ tosa_reference_model_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a"

# vela
vela_repo_url="https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela"
vela_rev="425541302c7e4b6fbeca7c0061286b131ee507c3"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the heads up @zingo / @AdrianLundell . I think we will pull this one in.

vela_rev="859cc066178a87ff28230c1ce9bd370f1e98aa5a"

########
### Optional user args
Expand Down
Loading