
Commit bbdfebb

Merge branch 'main' into intel-mac-check

2 parents 4264ef6 + 4f748fe

21 files changed: +806 -89 lines

Diff for: backends/arm/_passes/__init__.py (+1)

@@ -39,6 +39,7 @@
 from .insert_table_ops import InsertTableOpsPass # noqa
 from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass # noqa
+from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass # noqa
 from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
 from .remove_clone_pass import RemoveClonePass # noqa

Diff for: backends/arm/_passes/arm_pass_manager.py (+3)

@@ -40,6 +40,7 @@
     InsertTableOpsPass,
     KeepDimsFalseToSqueezePass,
     MatchArgRanksPass,
+    MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
     ReplaceScalarWithTensorArgPassTOSABI,

@@ -80,6 +81,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(MatchWhereSelfDtypePass())
         if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())

@@ -130,6 +132,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(MatchWhereSelfDtypePass())

         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())

Diff for: backends/arm/_passes/arm_pass_utils.py (+8 -3)

@@ -13,6 +13,7 @@
 import torch
 import torch.fx
+from executorch.backends.arm.tosa_utils import get_node_debug_info
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops

@@ -169,9 +170,13 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
     else:
         fake_tensor = node.meta["val"]

-    assert isinstance(
-        fake_tensor, FakeTensor
-    ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
+    if not isinstance(fake_tensor, FakeTensor):
+        raise TypeError(
+            f'Expected a FakeTensor in meta["val"] of node {node}, but got '
+            f"{type(fake_tensor).__name__}\n"
+            f"{get_node_debug_info(node)}"
+        )
+
     return fake_tensor
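Note on the change above: unlike `assert`, which is compiled out when Python runs with `-O`, the explicit `raise TypeError` always fires and can be caught by callers. A minimal sketch of catching it (the wrapper `first_fake_tensor_or_report` is hypothetical, not part of this commit):

    from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor

    def first_fake_tensor_or_report(node):
        """Hypothetical wrapper: the explicit raise can be caught and reported,
        whereas the old assert vanished under `python -O`."""
        try:
            return get_first_fake_tensor(node)
        except TypeError as err:
            # err carries the offending type name plus get_node_debug_info(node)
            print(f"Cannot lower node {node.name}: {err}")
            return None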

Diff for: backends/arm/_passes/match_arg_ranks_pass.py (+1)

@@ -49,6 +49,7 @@ def __init__(self, exported_program):
             exir_ops.edge.aten.bitwise_left_shift.Tensor,
             exir_ops.edge.aten.eq.Tensor,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            exir_ops.edge.aten.where.self,
         ]

     def _match_op_rank(self, graph_module, node, arg, max_rank):

Diff for: backends/arm/_passes/match_where_self_arg_dtype_pass.py (+95)

@@ -0,0 +1,95 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+DTYPE_RANK = {
+    torch.bool: 0,
+    torch.uint8: 1,
+    torch.int8: 2,
+    torch.int16: 3,
+    torch.int32: 4,
+    torch.int64: 5,
+    torch.float16: 6,
+    torch.float32: 7,
+    torch.float64: 8,
+}
+
+
+def get_largest_dtype(dtype_1, dtype_2):
+    """Find the largest dtype."""
+    return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2
+
+
+class MatchWhereSelfDtypePass(ExportPass):
+    """Pass to match data types of non-condition input tensors.
+
+    Edge dialect allows different data types for non-condition tensors, while TOSA
+    does not. In cases where they differ a TOSA CAST operator is inserted.
+
+    There is an edge case where one input is `boolean`, which cannot be directly cast
+    to, for example, float32. When this occurs two CAST operators are added to first
+    cast to int8 and then to the correct target data type.
+
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified_graph = False
+        graph = graph_module.graph
+        node_list = graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.where.self
+        )
+        for node in node_list:
+            cond, input_, other_ = node.args
+
+            input_dtype = input_.meta["val"].dtype
+            other_dtype = other_.meta["val"].dtype
+            target_dtype = torch.float32
+            if input_dtype != other_dtype:
+                target_dtype = get_largest_dtype(input_dtype, other_dtype)
+
+            for arg in node.args[1:]:
+                arg_dtype = arg.meta["val"].dtype
+
+                if arg_dtype != target_dtype:
+                    if arg_dtype == torch.bool:
+                        # Bool is an edge case which cannot necessarily be directly
+                        # converted to the target data type.
+                        with graph.inserting_after(arg):
+                            replace_node_int8 = create_node(
+                                graph,
+                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                            )
+                            replace_node_int8.args = (arg,)
+                            replace_node_int8.kwargs = {"dtype": torch.int8}
+
+                        with graph.inserting_after(replace_node_int8):
+                            replace_node_fp32 = create_node(
+                                graph,
+                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                            )
+                            replace_node_fp32.args = (replace_node_int8,)
+                            replace_node_fp32.kwargs = {"dtype": target_dtype}
+                            node.replace_input_with(arg, replace_node_fp32)
+                    else:
+                        with graph.inserting_after(arg):
+                            replace_node = create_node(
+                                graph,
+                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                            )
+                            replace_node.args = (arg,)
+                            replace_node.kwargs = {"dtype": target_dtype}
+                            node.replace_input_with(arg, replace_node)
+
+                    modified_graph = True
+
+        if modified_graph:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified_graph)
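For intuition, a minimal module that trips this pass (a hypothetical repro, not part of the commit): its `aten.where.self` mixes int32 and float32 value inputs, so after lowering to edge dialect the pass should wrap the int32 branch in a `_to_dim_order_copy` cast to float32, the larger dtype per DTYPE_RANK.

    import torch

    class MixedWhere(torch.nn.Module):
        def forward(self, cond, x_int32, y_fp32):
            # Branches disagree on dtype; TOSA SELECT requires them to match.
            return torch.where(cond, x_int32, y_fp32)

    m = MixedWhere()
    args = (torch.rand(4) > 0.5, torch.ones(4, dtype=torch.int32), torch.rand(4))
    out = m(*args)  # eager mode promotes to float32 implicitly; the pass
                    # makes that promotion explicit as a CAST in the graph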

Diff for: backends/arm/operator_support/ethos_u55_support.py (+1)

@@ -149,6 +149,7 @@ class EthosU55NotSupported(OperatorSupportBase):
         exir_ops.edge.aten.reflection_pad1d.default, # REVERSE
         exir_ops.edge.aten.reflection_pad2d.default, # REVERSE
         exir_ops.edge.aten.reflection_pad3d.default, # REVERSE
+        exir_ops.edge.aten.where.self, # SELECT
     ]

     def __init__(self, reporter: WhyNoPartitionReporter):

Diff for: backends/arm/operator_support/tosa_supported_operators.py (+1)

@@ -207,6 +207,7 @@ def is_node_supported(
             exir_ops.edge.aten.squeeze_copy.dims,
             exir_ops.edge.aten.pow.Tensor_Scalar,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            exir_ops.edge.aten.where.self,
             operator.getitem,
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,

Diff for: backends/arm/operators/__init__.py (+1)

@@ -49,6 +49,7 @@
     op_transpose,
     op_upsample_nearest2d,
     op_view,
+    op_where,
     ops_binary,
     ops_unary,
 )

Diff for: backends/arm/operators/op_where.py (+103)

@@ -0,0 +1,103 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Sequence
+
+import serializer.tosa_serializer as ts  # type: ignore
+
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+def _add_node_to_tosa_graph(
+    tosa_graph: ts.TosaSerializer,
+    inputs: List[TosaArg],
+    output: TosaArg,
+    supported_dtypes: Sequence,
+) -> None:
+    if len(inputs) != 3:
+        raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}")
+
+    if inputs[0].dtype is not ts.DType.BOOL:
+        raise ValueError("Input 0 needs to have dtype BOOL")
+    if inputs[1].dtype != inputs[2].dtype:
+        raise ValueError(
+            "Non-condition tensors must have same data type, got "
+            f"{inputs[1].dtype} and {inputs[2].dtype}"
+        )
+    for input_ in inputs[1:]:
+        if input_.dtype not in supported_dtypes:
+            raise ValueError(
+                f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}"
+            )
+
+    tosa_graph.addOperator(
+        TosaOp.Op().SELECT,
+        [inputs[0].name, inputs[1].name, inputs[2].name],
+        [output.name],
+        None,
+    )
+
+
+@register_node_visitor
+class WhereVisitor_080_BI(NodeVisitor):
+    target = "aten.where.self"
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+
+        bi_supported_dtypes = [
+            ts.DType.INT8,
+            ts.DType.INT16,
+            ts.DType.INT32,
+            ts.DType.BOOL,
+        ]
+        _add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes)
+
+
+@register_node_visitor
+class WhereVisitor_080_MI(WhereVisitor_080_BI):
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        mi_supported_dtypes = [
+            ts.DType.FP16,
+            ts.DType.FP32,
+            ts.DType.INT8,
+            ts.DType.INT16,
+            ts.DType.INT32,
+            ts.DType.BOOL,
+        ]
+        _add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes)
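Semantically, the TOSA SELECT operator this visitor emits is a pure elementwise routing of values; a numpy model of it (illustrative only, not the serializer API):

    import numpy as np

    def tosa_select(cond, on_true, on_false):
        # out[i] = on_true[i] if cond[i] else on_false[i].
        # Matching dtypes are enforced by _add_node_to_tosa_graph above;
        # matching ranks are ensured earlier by MatchArgRanksPass.
        assert on_true.dtype == on_false.dtype
        return np.where(cond, on_true, on_false)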

Diff for: backends/arm/quantizer/quantization_annotator.py (+13 -5)

@@ -238,13 +238,14 @@ def _match_pattern(
     torch.ops.aten.dropout_.default,
     torch.ops.aten.clamp.default,
     torch.ops.aten.clamp.Tensor,
+    torch.ops.aten.where,
     operator.getitem,
 ]


 def get_quant_properties(  # noqa: C901
     node: Node, gm: torch.fx.GraphModule, quantization_config
-) -> _OpQuantProperties:
+) -> _OpQuantProperties | None:
     input_act_qspec = quantization_config.get_input_act_qspec()
     weight_qspec = quantization_config.get_weight_qspec()
     output_act_qspec = quantization_config.get_output_act_qspec()

@@ -322,6 +323,13 @@ def any_or_hardtanh_min_zero(n: Node):
            ),
        ]
        quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
+    elif node.target in (torch.ops.aten.where.self,):
+        shared_qspec = SharedQuantizationSpec(node.args[1])  # type: ignore[arg-type]
+        quant_properties.quant_inputs = [
+            _QuantProperty(1, shared_qspec),  # type: ignore[arg-type]
+            _QuantProperty(2, shared_qspec),  # type: ignore[arg-type]
+        ]
+        quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     elif node.target == torch.ops.aten.adaptive_avg_pool2d.default:
         input_qspec = (
             SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]

@@ -376,16 +384,16 @@ def any_or_hardtanh_min_zero(n: Node):
         quant_properties.quant_output = None
     elif node.target in _parent_shared_qspec:
         if not isinstance(node.args[0], Node):
-            return None  # type: ignore[return-value]
+            return None

         if not arm_quantizer_utils.is_output_annotated(node.args[0]):  # type: ignore[attr-defined]
-            return None  # type: ignore[return-value]
+            return None

         shared_qspec = SharedQuantizationSpec(node.args[0])
         quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]  # type: ignore[arg-type]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     else:
-        return None  # type: ignore[return-value]
+        return None

     # Don't check if operator.getitem is ok for quantization, it's always ok
     if node.target == operator.getitem:

@@ -394,7 +402,7 @@ def any_or_hardtanh_min_zero(n: Node):
     # Check that each inputs/outputs can be quantized properly with the
     # provided quantization properties.
     if not _is_ok_for_quantization(node, quant_properties, gm):
-        return None  # type: ignore[return-value]
+        return None

     return quant_properties
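Why `where` uses one SharedQuantizationSpec across both value inputs and the output: SELECT only routes quantized values, so all three must decode with the same (scale, zero_point), or the two branches would disagree numerically. A quick illustration under an assumed int8 affine scheme (not code from the commit):

    import numpy as np

    scale, zp = 0.05, 0
    def quant(x):   return np.clip(np.round(x / scale) + zp, -128, 127).astype(np.int8)
    def dequant(q): return (q.astype(np.float32) - zp) * scale

    cond = np.array([True, False, True])
    a = np.array([0.5, 1.0, -0.2], dtype=np.float32)
    b = np.array([0.1, 0.3, 0.9], dtype=np.float32)

    # With a single shared spec, a pure integer select is exact up to
    # quantization error; no rescale is needed on either branch.
    q_out = np.where(cond, quant(a), quant(b))
    assert np.allclose(dequant(q_out), np.where(cond, a, b), atol=scale / 2)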
