diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index ddca8ea4a06..09b5b50e8c5 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -38,6 +38,7 @@
 )
 from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass  # noqa
 from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass  # noqa
+from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass  # noqa
 from .fuse_quantized_activation_pass import FuseQuantizedActivationPass  # noqa
 from .insert_rescales_pass import InsertRescalePass  # noqa
 from .insert_table_ops import InsertTableOpsPass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index dd4ca7ad7bd..0dd4f67cf7c 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -39,6 +39,7 @@
     FoldAndAnnotateQParamsPass,
     FuseBatchnorm2DPass,
     FuseConstantArgsPass,
+    FuseEqualPlaceholdersPass,
     FuseQuantizedActivationPass,
     InsertRescalePass,
     InsertTableOpsPass,
@@ -112,6 +113,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(FuseConstantArgsPass(exported_program))
 
         self.add_pass(InsertTableOpsPass(exported_program))
+        self.add_pass(FuseEqualPlaceholdersPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
 
@@ -162,6 +164,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(FuseViewCopyTransform())
         self.add_pass(FuseConstantArgsPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
+        self.add_pass(FuseEqualPlaceholdersPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
 
diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py
new file mode 100644
index 00000000000..cd8cce1b3ea
--- /dev/null
+++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py
@@ -0,0 +1,108 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import (
+    get_constant_placeholder_kind,
+    get_param_tensor,
+    is_param_node,
+)
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
+from executorch.exir import ExportedProgram
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class FuseEqualPlaceholdersPass(ExportPass):
+    """
+    This pass optimizes memory usage by finding constant placeholders
+    pointing to identical tensors and fusing them to one single placeholder
+    with multiple users.
+
+    Two tensors are only treated as identical when their dtype, shape and
+    element values all match.
+    """
+
+    def __init__(self, exported_program: ExportedProgram):
+        self.exported_program = exported_program
+        super().__init__()
+
+    @staticmethod
+    def _tensors_match(lhs: torch.Tensor, rhs: torch.Tensor) -> bool:
+        """Return True if lhs and rhs agree in dtype, shape and values."""
+        # torch.equal only documents a size/element comparison, so the dtype
+        # is checked explicitly to keep e.g. an int32 and a float32 constant
+        # holding the same numbers from being merged into one placeholder.
+        return (
+            lhs.dtype == rhs.dtype
+            and lhs.shape == rhs.shape
+            and torch.equal(lhs, rhs)
+        )
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        """
+        Fuse groups of equal constant placeholders in graph_module.
+
+        Returns a PassResult holding the graph module (re-run through
+        ExportPass.call when anything changed) and a flag telling whether
+        any placeholders were fused.
+        """
+        modified = False
+
+        # Look up every constant placeholder's tensor once; placeholders
+        # without a tensor are never fused and are left untouched.
+        candidates = []
+        for node in graph_module.graph.nodes:
+            if not is_param_node(self.exported_program, node):
+                continue
+            tensor = get_param_tensor(self.exported_program, node)
+            if tensor is not None:
+                candidates.append((node, tensor))
+
+        while candidates:
+            # Split the remaining candidates into duplicates of node1 and
+            # nodes that still need to be examined in a later iteration.
+            node1, tensor1 = candidates.pop()
+            duplicates = []
+            remaining = []
+            for node2, tensor2 in candidates:
+                if self._tensors_match(tensor1, tensor2):
+                    duplicates.append(node2)
+                else:
+                    remaining.append((node2, tensor2))
+
+            if not duplicates:
+                continue
+            candidates = remaining
+
+            common_name = node1.name + "_common"
+            common_kind = get_constant_placeholder_kind(
+                self.exported_program, node1
+            )
+            common_persistent_buffer = True
+
+            with graph_module.graph.inserting_before(node1):
+                common_node = create_constant_placeholder(
+                    self.exported_program,
+                    graph_module.graph,
+                    common_name,
+                    common_kind,
+                    tensor1,
+                    common_persistent_buffer,
+                )
+
+            # Redirect all users to the fused placeholder before deleting
+            # the original placeholders.
+            for eq_node in (node1, *duplicates):
+                eq_node.replace_all_uses_with(common_node)
+                delete_constant_placeholder(self.exported_program, eq_node)
+
+            modified = True
+
+        if modified:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module=graph_module, modified=modified)
diff --git a/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py b/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py
new file mode 100644
index 00000000000..2674f45cf6a
--- /dev/null
+++ b/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py
@@ -0,0 +1,96 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from copy import deepcopy
+from typing import Tuple
+
+import torch
+from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
+    FuseEqualPlaceholdersPass,
+)
+from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+
+input_t = Tuple[torch.Tensor]  # Input x
+
+
+class FuseWeightsConstants(torch.nn.Module):
+    ops_before_pass = {}
+    ops_after_pass = {}
+    ops_not_after_pass = []
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+        self.weights1 = torch.rand(1, 2, 1)
+        self.weights2 = deepcopy(self.weights1)
+        self.bias1 = torch.rand(1)
+        self.bias2 = deepcopy(self.bias1)
+        self.bias3 = deepcopy(self.bias1)
+
+    def forward(self, x):
+        return (
+            torch.conv1d(x, self.weights1, self.bias1)
+            + torch.conv1d(x, self.weights2, self.bias2)
+            + self.bias3
+        )
+
+
+class FuseWeightsStateDict(torch.nn.Module):
+    ops_before_pass = {}
+    ops_after_pass = {}
+    ops_not_after_pass = []
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+        self.fc1 = torch.nn.Linear(in_features=8, out_features=2, bias=True)
+        self.fc2 = deepcopy(self.fc1)
+
+    def forward(self, x):
+        return self.fc1(x) + self.fc2(x)
+
+
+def test_fuse_equal_placeholders_constants_tosa_MI():
+    module = FuseWeightsConstants()
+    data = (torch.rand(1, 2, 8),)
+    pipeline = PassPipeline[input_t](
+        module,
+        data,
+        tosa_version="TOSA-0.80+MI",
+        ops_before_pass=module.ops_before_pass,
+        ops_after_pass=module.ops_after_pass,
+        passes_with_exported_program=[FuseEqualPlaceholdersPass],
+    )
+    pipeline.run()
+
+    # Check that weights and bias have been merged.
+    exp_program = pipeline.tester.get_artifact().exported_program()
+    constant_keys = list(exp_program.constants.keys())
+    assert len(constant_keys) == 2, "FuseEqualPlaceholders constants failed"
+    assert "_common" in constant_keys[0], "FuseEqualPlaceholders constants failed"
+    assert "_common" in constant_keys[1], "FuseEqualPlaceholders constants failed"
+
+
+def test_fuse_equal_placeholders_state_dict_tosa_MI():
+    module = FuseWeightsStateDict()
+    data = (torch.rand(1, 2, 8),)
+    pipeline = PassPipeline[input_t](
+        module,
+        data,
+        tosa_version="TOSA-0.80+MI",
+        ops_before_pass=module.ops_before_pass,
+        ops_after_pass=module.ops_after_pass,
+        passes_with_exported_program=[FuseEqualPlaceholdersPass],
+    )
+    pipeline.run()
+
+    # Check that weights and bias have been merged.
+    exp_program = pipeline.tester.get_artifact().exported_program()
+    state_dict_keys = list(exp_program.state_dict.keys())
+    assert len(state_dict_keys) == 2, "FuseEqualPlaceholders state_dict failed"
+    assert "_common" in state_dict_keys[0], "FuseEqualPlaceholders state_dict failed"
+    assert "_common" in state_dict_keys[1], "FuseEqualPlaceholders state_dict failed"
diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh
index 8d77eabce0f..44f09f211b5 100755
--- a/examples/arm/setup.sh
+++ b/examples/arm/setup.sh
@@ -57,7 +57,7 @@ fi
 
 # vela
 vela_repo_url="https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela"
-vela_rev="425541302c7e4b6fbeca7c0061286b131ee507c3"
+vela_rev="859cc066178a87ff28230c1ce9bd370f1e98aa5a"
 
 ########
 ### Optional user args