-
Notifications
You must be signed in to change notification settings - Fork 551
Arm backend: Add FuseEqualPlaceholdersPass #9893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from executorch.backends.arm._passes.arm_pass_utils import ( | ||
get_constant_placeholder_kind, | ||
get_param_tensor, | ||
is_param_node, | ||
) | ||
from executorch.backends.transforms.utils import ( | ||
create_constant_placeholder, | ||
delete_constant_placeholder, | ||
) | ||
from executorch.exir import ExportedProgram | ||
from executorch.exir.pass_base import ExportPass, PassResult | ||
|
||
|
||
class FuseEqualPlaceholdersPass(ExportPass): | ||
""" | ||
This pass optimizes memory usage by finding constant placeholders | ||
pointing to identical tensors and fusing them to one single placeholder | ||
with multiple users. | ||
""" | ||
|
||
def __init__(self, exported_program: ExportedProgram): | ||
self.exported_program = exported_program | ||
super().__init__() | ||
|
||
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: | ||
modified = False | ||
const_placeholder_nodes = [] | ||
for node in graph_module.graph.nodes: | ||
if is_param_node(self.exported_program, node): | ||
const_placeholder_nodes.append(node) | ||
|
||
while const_placeholder_nodes: | ||
|
||
# Find equal tensors | ||
node1 = const_placeholder_nodes.pop() | ||
eq_nodes = [node1] | ||
tensor1 = get_param_tensor(self.exported_program, node1) | ||
if tensor1 is None: | ||
continue | ||
|
||
for node2 in const_placeholder_nodes: | ||
tensor2 = get_param_tensor(self.exported_program, node2) | ||
if tensor2 is None: | ||
continue | ||
|
||
if torch.equal(tensor1, tensor2): | ||
eq_nodes.append(node2) | ||
|
||
if len(eq_nodes) > 1: | ||
common_name = node1.name + "_common" | ||
common_kind = get_constant_placeholder_kind( | ||
self.exported_program, node1 | ||
) | ||
common_persisten_buffer = True | ||
|
||
with graph_module.graph.inserting_before(node1): | ||
common_node = create_constant_placeholder( | ||
self.exported_program, | ||
graph_module.graph, | ||
common_name, | ||
common_kind, | ||
tensor1, | ||
common_persisten_buffer, | ||
) | ||
|
||
for eq_node in eq_nodes: | ||
eq_node.replace_all_uses_with(common_node) | ||
delete_constant_placeholder(self.exported_program, eq_node) | ||
if eq_node != node1: | ||
const_placeholder_nodes.remove(eq_node) | ||
|
||
modified = True | ||
|
||
if modified: | ||
graph_module.recompile() | ||
graph_module = super().call(graph_module).graph_module | ||
return PassResult(graph_module=graph_module, modified=modified) |
96 changes: 96 additions & 0 deletions
96
backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from copy import deepcopy | ||
from typing import Tuple | ||
|
||
import torch | ||
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import ( | ||
FuseEqualPlaceholdersPass, | ||
) | ||
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline | ||
|
||
input_t = Tuple[torch.Tensor] # Input x | ||
|
||
|
||
class FuseWeightsConstants(torch.nn.Module): | ||
ops_before_pass = {} | ||
ops_after_pass = {} | ||
ops_not_after_pass = [] | ||
|
||
def __init__( | ||
self, | ||
): | ||
super().__init__() | ||
self.weights1 = torch.rand(1, 2, 1) | ||
self.weights2 = deepcopy(self.weights1) | ||
self.bias1 = torch.rand(1) | ||
self.bias2 = deepcopy(self.bias1) | ||
self.bias3 = deepcopy(self.bias1) | ||
|
||
def forward(self, x): | ||
return ( | ||
torch.conv1d(x, self.weights1, self.bias1) | ||
+ torch.conv1d(x, self.weights2, self.bias2) | ||
+ self.bias3 | ||
) | ||
|
||
|
||
class FuseWeightsStateDict(torch.nn.Module): | ||
ops_before_pass = {} | ||
ops_after_pass = {} | ||
ops_not_after_pass = [] | ||
|
||
def __init__( | ||
self, | ||
): | ||
super().__init__() | ||
self.fc1 = torch.nn.Linear(in_features=8, out_features=2, bias=True) | ||
self.fc2 = deepcopy(self.fc1) | ||
|
||
def forward(self, x): | ||
return self.fc1(x) + self.fc2(x) | ||
|
||
|
||
def test_fuse_equal_placeholders_constants_tosa_MI(): | ||
module = FuseWeightsConstants() | ||
data = (torch.rand(1, 2, 8),) | ||
pipeline = PassPipeline[input_t]( | ||
module, | ||
data, | ||
tosa_version="TOSA-0.80+MI", | ||
ops_before_pass=module.ops_before_pass, | ||
ops_after_pass=module.ops_after_pass, | ||
passes_with_exported_program=[FuseEqualPlaceholdersPass], | ||
) | ||
pipeline.run() | ||
|
||
# Check that weights and bias has been merged. | ||
exp_program = pipeline.tester.get_artifact().exported_program() | ||
constant_keys = list(exp_program.constants.keys()) | ||
assert len(constant_keys) == 2, "FuseEqualPlaceholders constants failed" | ||
assert "_common" in constant_keys[0], "FuseEqualPlaceholders constants failed" | ||
assert "_common" in constant_keys[1], "FuseEqualPlaceholders constants failed" | ||
|
||
|
||
def test_fuse_equal_placeholders_state_dict_tosa_MI(): | ||
module = FuseWeightsStateDict() | ||
data = (torch.rand(1, 2, 8),) | ||
pipeline = PassPipeline[input_t]( | ||
module, | ||
data, | ||
tosa_version="TOSA-0.80+MI", | ||
ops_before_pass=module.ops_before_pass, | ||
ops_after_pass=module.ops_after_pass, | ||
passes_with_exported_program=[FuseEqualPlaceholdersPass], | ||
) | ||
pipeline.run() | ||
|
||
# Check that weights and bias has been merged. | ||
exp_program = pipeline.tester.get_artifact().exported_program() | ||
state_dict_keys = list(exp_program.state_dict.keys()) | ||
assert len(state_dict_keys) == 2, "FuseEqualPlaceholders state_dict failed" | ||
assert "_common" in state_dict_keys[0], "FuseEqualPlaceholders state_dict failed" | ||
assert "_common" in state_dict_keys[1], "FuseEqualPlaceholders state_dict failed" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the heads up @zingo / @AdrianLundell . I think we will pull this one in.