Skip to content

Commit 3fe89b1

Browse files
authored
Merge branch 'main' into pr_model_improve
2 parents 50e4783 + 1505903 commit 3fe89b1

20 files changed

+541
-182
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .decompose_select import DecomposeSelectPass # noqa
2828
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
2929
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
30+
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
3031
from .decompose_var_pass import DecomposeVarPass # noqa
3132
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
3233
FoldAndAnnotateQParamsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
DecomposeSelectPass,
3333
DecomposeSoftmaxPass,
3434
DecomposeSoftmaxUnstablePass,
35+
DecomposeSqrtPass,
3536
DecomposeVarPass,
3637
FoldAndAnnotateQParamsPass,
3738
FuseBatchnorm2DPass,
@@ -115,6 +116,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115116
return self._transform(exported_program.graph_module)
116117

117118
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
119+
self.add_pass(DecomposeSqrtPass())
118120
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
119121
self.add_pass(FuseQuantizedActivationPass())
120122
self.add_pass(RemoveGetItemPass())
@@ -181,6 +183,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
181183
self.add_pass(DecomposeMeanDimPass())
182184
self.add_pass(DecomposeDivPass())
183185
self.add_pass(DecomposeLeakyReLUPass())
186+
self.add_pass(DecomposeSqrtPass())
184187

185188
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
186189
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
edge_sqrt_ops = (exir_ops.edge.aten.sqrt.default,)
12+
aten_sqrt_ops = (
13+
torch.ops.aten.sqrt.default,
14+
torch.ops.aten.sqrt_.default,
15+
)
16+
17+
18+
def get_sqrt_decomposition(op):
    """Return the pow operator used to decompose ``sqrt`` for the given op.

    Args:
        op: A sqrt operator overload, in either the edge or the aten dialect.

    Returns:
        The ``pow.Tensor_Scalar`` overload from the same dialect as ``op``.

    Raises:
        RuntimeError: If ``op`` is not a recognized sqrt operator.
    """
    # NOTE: the original annotated the return as ``tuple``, but a single
    # operator overload is returned — the annotation was dropped as incorrect.
    # TODO: MLETORCH-863 — replace the current sqrt -> pow.Tensor_Scalar
    # workaround with pow.Tensor_Tensor.
    if op in edge_sqrt_ops:
        return exir_ops.edge.aten.pow.Tensor_Scalar
    if op in aten_sqrt_ops:
        return torch.ops.aten.pow.Tensor_Scalar
    raise RuntimeError(f"Can't get sqrt decomposition for op {op}")
25+
26+
27+
class DecomposeSqrtPass(ExportPass):
    """Decomposes ``sqrt(x)`` into ``pow(x, 0.5)`` for backend support."""

    def call_operator(self, op, args, kwargs, meta):
        # Only sqrt variants (edge or aten dialect) are rewritten; everything
        # else is forwarded to the default handling unchanged.
        if op in edge_sqrt_ops or op in aten_sqrt_ops:
            pow_op = get_sqrt_decomposition(op)
            # sqrt(x) == x ** 0.5; the original node metadata is kept.
            return super().call_operator(pow_op, (args[0], 0.5), {}, meta)
        return super().call_operator(op, args, kwargs, meta)

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def __init__(self, exported_program):
4848
exir_ops.edge.aten.bitwise_right_shift.Tensor,
4949
exir_ops.edge.aten.bitwise_left_shift.Tensor,
5050
exir_ops.edge.aten.eq.Tensor,
51+
exir_ops.edge.aten.gt.Tensor,
52+
exir_ops.edge.aten.lt.Tensor,
5153
exir_ops.edge.aten.pow.Tensor_Tensor,
5254
exir_ops.edge.aten.where.self,
5355
]

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@
2626
exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor,
2727
exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
2828
exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
29+
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
30+
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
2931
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3032
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
3133
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
3234
torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
3335
torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor,
3436
torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
3537
torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
38+
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
39+
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
3640
}
3741

3842

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,10 @@ class EthosU55NotSupported(OperatorSupportBase):
135135
exir_ops.edge.aten.eq.Scalar,
136136
exir_ops.edge.aten.ge.Tensor,
137137
exir_ops.edge.aten.gt.Tensor,
138+
exir_ops.edge.aten.gt.Scalar,
138139
exir_ops.edge.aten.le.Tensor,
139140
exir_ops.edge.aten.lt.Tensor,
141+
exir_ops.edge.aten.lt.Scalar,
140142
exir_ops.edge.aten.flip.default, # REVERSE
141143
exir_ops.edge.aten.grid_sampler_2d, # GATHER
142144
exir_ops.edge.aten.scatter.src,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,10 @@ def is_node_supported(
176176
exir_ops.edge.aten.full_like.default,
177177
exir_ops.edge.aten.ge.Tensor,
178178
exir_ops.edge.aten.gt.Tensor,
179+
exir_ops.edge.aten.gt.Scalar,
179180
exir_ops.edge.aten.le.Tensor,
180181
exir_ops.edge.aten.lt.Tensor,
182+
exir_ops.edge.aten.lt.Scalar,
181183
exir_ops.edge.aten.mul.Tensor,
182184
exir_ops.edge.aten.add.Scalar,
183185
exir_ops.edge.aten.sub.Scalar,
@@ -194,6 +196,7 @@ def is_node_supported(
194196
exir_ops.edge.aten.reciprocal.default,
195197
exir_ops.edge.aten.relu.default,
196198
exir_ops.edge.aten.leaky_relu.default,
199+
exir_ops.edge.aten.sqrt.default,
197200
exir_ops.edge.aten.rsqrt.default,
198201
exir_ops.edge.aten._softmax.default,
199202
exir_ops.edge.aten.select_copy.int,
@@ -256,6 +259,7 @@ def is_node_supported(
256259
exir_ops.edge.aten.var.correction,
257260
exir_ops.edge.aten.var.dim,
258261
exir_ops.edge.aten.add.Scalar,
262+
exir_ops.edge.aten.sqrt.default,
259263
exir_ops.edge.aten.sub.Scalar,
260264
exir_ops.edge.aten.mul.Scalar,
261265
exir_ops.edge.aten.div.Scalar,

backends/arm/operators/op_clamp.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def cast_type(value: Any) -> int | float:
6363
# Attempt to cast to float
6464
return float(value)
6565

66-
assert 2 <= len(node.args) <= 3
66+
if len(node.args) != 2 and len(node.args) != 3:
67+
raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}")
6768

6869
min_arg = dtype_min
6970
max_arg = dtype_max
@@ -84,7 +85,10 @@ def define_node(
8485
inputs: List[TosaArg],
8586
output: TosaArg,
8687
) -> None:
87-
assert len(node.all_input_nodes) == 1
88+
if len(node.all_input_nodes) != 1:
89+
raise ValueError(
90+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
91+
)
8892

8993
min_int8, max_int8 = self._get_min_max_arguments(
9094
node,
@@ -122,7 +126,10 @@ def define_node(
122126
inputs: List[TosaArg],
123127
output: TosaArg,
124128
) -> None:
125-
assert len(node.all_input_nodes) == 1
129+
if len(node.all_input_nodes) != 1:
130+
raise ValueError(
131+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
132+
)
126133

127134
if inputs[0].dtype == ts.DType.INT8:
128135
# Call the inherited define_node for handling integers

backends/arm/operators/op_minimum.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,27 @@ def define_node(
3737
inputs: List[TosaArg],
3838
output: TosaArg,
3939
) -> None:
40-
assert inputs[0].dtype == inputs[1].dtype
40+
if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype:
41+
raise TypeError(
42+
f"Data type of inputs and output must be the same. Got input 0 dtype: "
43+
f"{inputs[0].dtype}, input 1 dtype: {inputs[1].dtype} and output "
44+
f"dtype: {output.dtype}"
45+
)
4146

4247
scale_back = 1.0
4348
min_output = output
4449
if inputs[0].dtype == ts.DType.INT8:
4550
input_qparams = get_input_qparams(node)
46-
assert (
47-
len(input_qparams) == 2
48-
), f"Both inputs needs to have quantization information for {node}"
49-
# insert RESCALEs to int32
50-
assert (
51-
input_qparams[0] == input_qparams[1]
52-
), "Both inputs must have same quantization for MIN"
51+
if len(input_qparams) != 2:
52+
raise ValueError(
53+
f"Both inputs need to have quantization information for {node}"
54+
)
55+
if input_qparams[0] != input_qparams[1]:
56+
raise ValueError(
57+
"Both inputs must have the same quantization parameters for MIN"
58+
)
5359

60+
# insert RESCALEs to int32
5461
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5562
tosa_graph, inputs, node
5663
)

backends/arm/process_node.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_specification import TosaSpecification
1717
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
18+
from torch._export.utils import (
19+
get_buffer,
20+
get_lifted_tensor_constant,
21+
get_param,
22+
is_buffer,
23+
is_lifted_tensor_constant,
24+
is_param,
25+
)
1826
from torch.export.exported_program import ExportedProgram
1927

2028

@@ -99,8 +107,7 @@ def process_inputs_to_parameters(
99107
f"Failed processing parameter placeholder: {node.name}. "
100108
"Is the original torch function supported?"
101109
) from e
102-
parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name]
103-
parameter_data = edge_program.state_dict[parameter_name]
110+
parameter_data = get_param(edge_program, node)
104111

105112
assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
106113
parameter_values = parameter_data.detach().numpy()
@@ -128,8 +135,7 @@ def process_inputs_to_buffers(
128135
f"Failed processing buffer placeholder: {node.name}. "
129136
"Is the original torch function supported?"
130137
) from e
131-
buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
132-
buffer_data = edge_program.state_dict[buffer_name]
138+
buffer_data = get_buffer(edge_program, node)
133139

134140
assert isinstance(buffer_data, torch.Tensor), "Expect Attr to be tensor"
135141
buffer_values = buffer_data.detach().numpy()
@@ -156,11 +162,8 @@ def process_inputs_to_lifted_tensor_constants(
156162
f"Failed processing lifted tensor constant placeholder: {node.name}. "
157163
"Is the original torch function supported?"
158164
) from e
159-
tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
160-
tosa_arg.name
161-
]
162-
tensor = edge_program.tensor_constants[tensor_name]
163-
tensor_data = tensor.detach().numpy()
165+
tensor = get_lifted_tensor_constant(edge_program, node)
166+
tensor_data = tensor.detach().numpy() # type: ignore[union-attr]
164167

165168
tosa_graph.addConst(
166169
tensor_data.shape, tosa_arg.dtype, tensor_data, name=tosa_arg.name
@@ -179,11 +182,11 @@ def process_placeholder(
179182

180183
if node.name in edge_program.graph_signature.user_inputs:
181184
process_inputs(node, tosa_graph, tosa_spec)
182-
elif node.name in edge_program.graph_signature.inputs_to_parameters:
185+
elif is_param(edge_program, node):
183186
process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
184-
elif node.name in edge_program.graph_signature.inputs_to_buffers:
187+
elif is_buffer(edge_program, node):
185188
process_inputs_to_buffers(node, tosa_graph, edge_program)
186-
elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants:
189+
elif is_lifted_tensor_constant(edge_program, node):
187190
process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
188191
elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs:
189192
raise NotImplementedError(
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
import torch.nn as nn
8+
9+
from executorch.backends.arm.test.common import parametrize
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
TosaPipelineBI,
12+
TosaPipelineMI,
13+
)
14+
15+
16+
class NonPersistentBuffer(nn.Module):
    """Minimal module that registers a non-persistent input buffer.

    The buffer takes part in the forward computation but is excluded from the
    module's ``state_dict`` because it is registered with ``persistent=False``.
    """

    def __init__(self):
        super().__init__()
        initial = torch.rand(2, 2, 2, 2)
        self.register_buffer("test_buff", initial, persistent=False)

    def forward(self, x):
        # Element-wise subtraction of the registered buffer from the input.
        return x - self.test_buff
27+
28+
29+
test_input = {"input": (torch.ones(2, 2, 2, 2),)}
30+
31+
input_t = tuple[torch.Tensor]
32+
33+
34+
@parametrize("test_data", test_input)
def test_non_persistent_buffer_MI(test_data: input_t):
    """
    Validate that the Arm backend handles non-persistent buffers through the
    MI TOSA pipeline without raising asserts or errors.
    """
    pipeline = TosaPipelineMI[input_t](NonPersistentBuffer(), test_data, "")
    pipeline.run()
41+
42+
43+
@parametrize("test_data", test_input)
def test_non_persistent_buffer_BI(test_data: input_t):
    """
    Validate that the Arm backend handles non-persistent buffers through the
    BI TOSA pipeline without raising asserts or errors.
    """
    pipeline = TosaPipelineBI[input_t](NonPersistentBuffer(), test_data, "")
    pipeline.run()

backends/arm/test/models/test_llama.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,6 @@ def prepare_model(self):
7979

8080
llama_model, llama_inputs, llama_meta = get_llama_model(args)
8181

82-
# TODO: Remove workaround since attention mask should not be persistent,
83-
# it only works if input shape is always the same
84-
freqs_c = "freqs_cos"
85-
freqs_s = "freqs_sin"
86-
for i in range(llama_model.n_layers):
87-
val = llama_model.layers[i].attention.get_buffer("mask")
88-
llama_model.layers[i].attention.register_buffer(
89-
"mask", val, persistent=True
90-
)
91-
val = llama_model.layers[i].attention.rope.get_buffer(freqs_c)
92-
llama_model.layers[i].attention.rope.register_buffer(
93-
freqs_c, val, persistent=True
94-
)
95-
val = llama_model.layers[i].attention.rope.get_buffer(freqs_s)
96-
llama_model.layers[i].attention.rope.register_buffer(
97-
freqs_s, val, persistent=True
98-
)
99-
10082
return llama_model, llama_inputs, llama_meta
10183

10284
def test_llama_tosa_MI(self):

0 commit comments

Comments
 (0)