Skip to content

Commit a4def9f

Browse files
Make ArmPassManager aware of TosaSpecification (#7668)
- Pass TosaSpecifcation to ArmPassManager. Based on this the PassManager can decide which passes should be run. - Also adds docstrings and renames some passes. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 6aa5c8a commit a4def9f

21 files changed

+225
-181
lines changed

backends/arm/_passes/arm_pass_manager.py

+63-52
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
# pyre-unsafe
99

10-
import torch
1110
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
1211
AnnotateChannelsLastDimOrder,
1312
)
@@ -47,7 +46,7 @@
4746
)
4847
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
4948
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
50-
ConvertMeanDimToAveragePool,
49+
ConvertMeanDimToAveragePoolPass,
5150
)
5251
from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
5352
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
@@ -61,86 +60,98 @@
6160
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
6261
UnsqueezeScalarPlaceholdersPass,
6362
)
63+
from executorch.backends.arm.tosa_specification import TosaSpecification
6464
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
6565
from executorch.exir import ExportedProgram
66-
from executorch.exir.dialects._ops import ops as exir_ops
6766
from executorch.exir.pass_manager import PassManager
67+
from torch.fx import GraphModule
6868

6969

7070
class ArmPassManager(PassManager):
7171

72-
def _transform(self, graph_module: torch.fx.GraphModule):
72+
def __init__(self, tosa_spec: TosaSpecification) -> None:
73+
self.tosa_spec = tosa_spec
74+
super().__init__()
75+
76+
def _transform(self, graph_module: GraphModule):
7377
return self(graph_module).graph_module
7478

75-
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
76-
"""Apply passes before transforming program to backend"""
79+
def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
7780
self.add_pass(FuseQuantizedActivationPass())
81+
self.add_pass(RemoveGetItemPass())
82+
self.add_pass(ConvertSplitToSlicePass())
83+
self.add_pass(ConvertMmToBmmPass())
7884
self.add_pass(DecomposeLinearPass())
85+
self.add_pass(ConvertMeanDimToAveragePoolPass())
86+
87+
self.add_pass(AnnotateDecomposedMatmulPass())
88+
self.add_pass(QuantizeFullArgument())
89+
self.add_pass(FoldAndAnnotateQParamsPass())
90+
self.add_pass(RetraceFoldedDtypesPass())
91+
self.add_pass(InsertTableOpsPass(exported_program))
92+
93+
self.add_pass(RemoveClonePass())
94+
self.add_pass(SizeAdjustConv2DPass())
95+
self.add_pass(ConvertExpandCopyToRepeatPass())
96+
self.add_pass(UnsqueezeBeforeRepeatPass())
97+
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
98+
self.add_pass(CastInt64ToInt32Pass(exported_program))
99+
self.add_pass(MatchArgRanksPass(exported_program))
100+
self.add_pass(KeepDimsFalseToSqueezePass())
101+
self.add_pass(Conv1dUnsqueezePass(exported_program))
102+
self.add_pass(DecomposeSelectPass())
103+
104+
self.add_pass(AnnotateChannelsLastDimOrder())
105+
106+
return self._transform(exported_program.graph_module)
107+
108+
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
109+
110+
self.add_pass(FuseQuantizedActivationPass())
79111
self.add_pass(RemoveGetItemPass())
112+
self.add_pass(ConvertSplitToSlicePass())
113+
self.add_pass(ConvertMmToBmmPass())
114+
self.add_pass(DecomposeLinearPass())
80115
self.add_pass(DecomposeLayerNormPass())
81116
self.add_pass(DecomposeVarPass())
82-
self.add_pass(ConvertMeanDimToAveragePool())
83117
self.add_pass(DecomposeMeanDimPass())
84-
self.add_pass(ConvertSplitToSlicePass())
85-
self.add_pass(ConvertMmToBmmPass())
86-
# TODO MLETORCH-558
118+
self.add_pass(ConvertMeanDimToAveragePoolPass())
119+
self.add_pass(DecomposeDivPass())
120+
self.add_pass(DecomposeSoftmaxesPass())
121+
87122
self.add_pass(AnnotateDecomposedMatmulPass())
88123
self.add_pass(QuantizeFullArgument())
89-
self.add_pass(
90-
FoldAndAnnotateQParamsPass(
91-
[
92-
exir_ops.edge.aten.minimum.default,
93-
exir_ops.edge.aten.maximum.default,
94-
exir_ops.edge.aten.add.Tensor,
95-
exir_ops.edge.aten.avg_pool2d.default,
96-
exir_ops.edge.aten.bmm.default,
97-
exir_ops.edge.aten.cat.default,
98-
exir_ops.edge.aten.convolution.default,
99-
exir_ops.edge.aten.clone.default,
100-
exir_ops.edge.aten.exp.default,
101-
exir_ops.edge.aten.expand_copy.default,
102-
exir_ops.edge.aten.full.default,
103-
exir_ops.edge.aten.hardtanh.default,
104-
exir_ops.edge.aten.log.default,
105-
exir_ops.edge.aten.max_pool2d.default,
106-
exir_ops.edge.aten.mul.Tensor,
107-
exir_ops.edge.aten.permute_copy.default,
108-
exir_ops.edge.aten.reciprocal.default,
109-
exir_ops.edge.aten.relu.default,
110-
exir_ops.edge.aten.repeat.default,
111-
exir_ops.edge.aten.rsqrt.default,
112-
exir_ops.edge.aten.select_copy.int,
113-
exir_ops.edge.aten.sigmoid.default,
114-
exir_ops.edge.aten.slice_copy.Tensor,
115-
exir_ops.edge.aten.squeeze_copy.dims,
116-
exir_ops.edge.aten.sub.Tensor,
117-
exir_ops.edge.aten.sum.dim_IntList,
118-
exir_ops.edge.aten.tanh.default,
119-
exir_ops.edge.aten.unsqueeze_copy.default,
120-
exir_ops.edge.aten.upsample_nearest2d.vec,
121-
exir_ops.edge.aten.view_copy.default,
122-
]
123-
)
124-
)
124+
self.add_pass(FoldAndAnnotateQParamsPass())
125125
self.add_pass(RetraceFoldedDtypesPass())
126126
self.add_pass(InsertTableOpsPass(exported_program))
127+
128+
self.add_pass(RemoveClonePass())
129+
self.add_pass(SizeAdjustConv2DPass())
127130
self.add_pass(ConvertExpandCopyToRepeatPass())
128131
self.add_pass(UnsqueezeBeforeRepeatPass())
129-
self.add_pass(CastInt64ToInt32Pass(exported_program))
130132
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
131-
self.add_pass(SizeAdjustConv2DPass())
132-
self.add_pass(RemoveClonePass())
133+
self.add_pass(CastInt64ToInt32Pass(exported_program))
133134
self.add_pass(MatchArgRanksPass(exported_program))
134-
self.add_pass(DecomposeDivPass())
135135
self.add_pass(KeepDimsFalseToSqueezePass())
136136
self.add_pass(Conv1dUnsqueezePass(exported_program))
137-
self.add_pass(DecomposeSoftmaxesPass())
138137
self.add_pass(DecomposeSelectPass())
138+
139139
self.add_pass(AnnotateChannelsLastDimOrder())
140140

141141
return self._transform(exported_program.graph_module)
142142

143-
def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
143+
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
144+
"""Apply passes before transforming program to backend"""
145+
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
146+
return self._tosa_080_BI_pipeline(exported_program)
147+
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
148+
return self._tosa_080_MI_pipeline(exported_program)
149+
else:
150+
raise NotImplementedError(
151+
f"No pass pipeline implemented for {self.tosa_spec=}"
152+
)
153+
154+
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
144155
self.add_pass(ScalarsToAttributePass())
145156
self.add_pass(DecomposeLayerNormPass())
146157
self.add_pass(DecomposeVarPass())

backends/arm/_passes/cast_int64_pass.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -17,6 +17,10 @@
1717

1818

1919
class CastInt64ToInt32Pass(ExportPass):
20+
"""
21+
Cast int64 buffers to int32 if the int64 data is in int32 range.
22+
"""
23+
2024
def __init__(self, exported_program: torch.export.ExportedProgram):
2125
super(CastInt64ToInt32Pass, self).__init__()
2226
self.exported_program = exported_program

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

77
import copy
88

9-
from typing import cast, Dict, Iterable, Set, Tuple
9+
from typing import cast, Dict, Set, Tuple
1010

1111
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1212

@@ -55,7 +55,7 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
5555
class FoldAndAnnotateQParamsPass(ExportPass):
5656
"""
5757
A pass that walks the graph and removes any DQ and Q nodes before and after the target
58-
node in the supplied list of operators.
58+
node.
5959
The quantization parameters from the DQ/Q nodes are stored as meta values to be
6060
accessible for later lowering and serialization passes.
6161
The assumption is that the quantization annotatation adds DQ nodes for all tensor
@@ -82,9 +82,8 @@ class FoldAndAnnotateQParamsPass(ExportPass):
8282
8383
"""
8484

85-
def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:
85+
def __init__(self) -> None:
8686
super().__init__()
87-
self.targeted_ops = targeted_ops
8887

8988
def fold_and_annotate_arg(
9089
self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
@@ -131,7 +130,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
131130
# Loop over the graph nodes and find any node in the 'targeted_ops' list.
132131
for n in graph_module.graph.nodes:
133132
n = cast(Node, n)
134-
if n.op != "call_function" or n.target not in self.targeted_ops:
133+
if n.op != "call_function":
135134
continue
136135

137136
# Make sure we haven't already set qparams meta information on the node
@@ -180,7 +179,7 @@ class QuantizeFullArgument(ExportPass):
180179

181180
def call(self, graph_module: GraphModule) -> PassResult:
182181
modified = False
183-
# Loop over the graph nodes and find any node in the 'targeted_ops' list.
182+
# Loop over the graph nodes and find full.default nodes.
184183
for n in graph_module.graph.nodes:
185184
n = cast(Node, n)
186185
if n.target != exir_ops.edge.aten.full.default:

backends/arm/_passes/meandim_to_averagepool_pass.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -16,7 +16,7 @@
1616
Argument = Any
1717

1818

19-
class ConvertMeanDimToAveragePool(ExportPass):
19+
class ConvertMeanDimToAveragePoolPass(ExportPass):
2020
"""
2121
Replace a mean operation with dim = [-1, -2] and keep_dim = True with an average pool operation.
2222
"""

backends/arm/_passes/remove_clone_pass.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -11,6 +11,7 @@
1111

1212

1313
class RemoveClonePass(ExportPass):
14+
"""Remove all clones from graph_module"""
1415

1516
def call_operator(self, op, args, kwargs, meta):
1617
if op != exir_ops.edge.aten.clone.default:

backends/arm/arm_backend.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(self):
5050
self.output_format = None
5151
self.path_for_intermediates = None
5252
self.quantize_io = False
53-
self.tosa_version = None
53+
self.tosa_spec = None
5454
self.input_order = None
5555

5656
def ethosu_compile_spec(
@@ -92,19 +92,26 @@ def ethosu_compile_spec(
9292
if "u55" in config:
9393
# Add the Ethos-U55 extension marker
9494
base_tosa_version += "+u55"
95-
self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)
95+
self.tosa_spec = TosaSpecification.create_from_string(base_tosa_version)
9696

9797
return self
9898

99-
def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder":
99+
def tosa_compile_spec(
100+
self, tosa_spec: str | TosaSpecification
101+
) -> "ArmCompileSpecBuilder":
100102
"""
101103
Generate compile spec for TOSA flatbuffer output
102104
"""
103105
assert (
104106
self.output_format is None
105107
), f"Output format already set: {self.output_format}"
106108
self.output_format = "tosa"
107-
self.tosa_version = TosaSpecification.create_from_string(tosa_version)
109+
if isinstance(tosa_spec, TosaSpecification):
110+
self.tosa_spec = tosa_spec
111+
elif isinstance(tosa_spec, str):
112+
self.tosa_spec = TosaSpecification.create_from_string(tosa_spec)
113+
else:
114+
raise RuntimeError(f"Invalid type for {tosa_spec}!")
108115
return self
109116

110117
def dump_intermediate_artifacts_to(
@@ -138,12 +145,10 @@ def build(self) -> List[CompileSpec]:
138145
"""
139146
Generate a list of compile spec objects from the builder
140147
"""
141-
assert self.tosa_version
148+
assert self.tosa_spec
142149

143150
# Always supply a TOSA version
144-
self.compile_spec = [
145-
CompileSpec("tosa_version", str(self.tosa_version).encode())
146-
]
151+
self.compile_spec = [CompileSpec("tosa_version", str(self.tosa_spec).encode())]
147152

148153
if self.output_format == "vela":
149154
self.compile_spec += [
@@ -253,7 +258,7 @@ def preprocess( # noqa: C901
253258
# Converted output for this subgraph, serializer needs path early as it emits
254259
# const data directly. Path created and data written only in debug builds.
255260
tosa_graph = ts.TosaSerializer(artifact_path)
256-
graph_module = ArmPassManager().transform_to_backend_pipeline(
261+
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline(
257262
exported_program=edge_program
258263
)
259264

backends/arm/quantizer/arm_quantizer.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
33
# All rights reserved.
44
#
55
# This source code is licensed under the BSD-style license found in the
@@ -24,6 +24,7 @@
2424
from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph
2525

2626
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
27+
from executorch.backends.arm.tosa_specification import TosaSpecification
2728
from torch.ao.quantization.fake_quantize import (
2829
FakeQuantize,
2930
FusedMovingAvgObsFakeQuantize,
@@ -205,8 +206,10 @@ def not_module_type_or_name_filter(n: Node) -> bool:
205206

206207

207208
class ArmQuantizer(Quantizer):
208-
def __init__(self) -> None:
209+
210+
def __init__(self, tosa_spec: TosaSpecification) -> None:
209211
super().__init__()
212+
self.tosa_spec = tosa_spec
210213
self.global_config: Optional[QuantizationConfig] = None
211214
self.io_config: Optional[QuantizationConfig] = None
212215
self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
@@ -250,7 +253,9 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
250253
Currently transforms scalar values to tensor attributes.
251254
"""
252255

253-
return ArmPassManager().transform_for_annotation_pipeline(graph_module=model)
256+
return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(
257+
graph_module=model
258+
)
254259

255260
def annotate(self, model: GraphModule) -> GraphModule:
256261
"""Performs the quantization annotation on the graph.

0 commit comments

Comments
 (0)