Skip to content

Commit bcf4b46

Browse files
authored
Arm backend: Add additional unsupported checks to Ethos-U55 backend (#9796)
- Add check for unsupported dtypes on Ethos-U55 - Add Ethos-U55 permute check - Move all Ethos-U55 support checks into a single file. Signed-off-by: Erik Lundell <[email protected]>
1 parent 5f01843 commit bcf4b46

File tree

7 files changed

+352
-143
lines changed

7 files changed

+352
-143
lines changed

backends/arm/operator_support/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from . import ( # noqa
99
convolution_support,
10+
ethos_u55_support,
1011
minmax_support,
1112
pool_2d_support,
1213
reduce_sum_support,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import typing
9+
10+
import torch
11+
import torch.fx as fx
12+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
13+
from executorch.backends.arm._passes.insert_table_ops import TableOps
14+
from executorch.backends.arm.operators.op_permute import transform_permutation_vector
15+
from executorch.backends.arm.tosa_utils import tosa_shape
16+
from executorch.exir.backend.utils import WhyNoPartitionReporter
17+
18+
from executorch.exir.dialects._ops import ops as exir_ops
19+
from torch.fx.passes.operator_support import OperatorSupportBase
20+
21+
22+
def _try_determine_dtype(node: fx.Node) -> torch.dtype | None:
    """Best-effort lookup of the non-float dtype that `node` operates on.

    Args:
        node: The FX node to inspect.

    Returns:
        The dtype of the node's first fake tensor when that dtype is not
        floating point. For floating point nodes, the dtype found on an
        adjacent dequantize/quantize node, or None when no such node is found.
    """
    dtype = get_first_fake_tensor(node).dtype
    if not dtype.is_floating_point:
        return dtype
    if node.target is exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default:
        # A dequantize node produces floats; report the dtype of its input.
        return get_first_fake_tensor(node.all_input_nodes[0]).dtype
    # Only the first user is inspected. A node without any users has no
    # quantize node to look at, so fall through to None instead of raising
    # IndexError.
    q_node = next(iter(node.users), None)
    if (
        q_node is not None
        and q_node.target
        is exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    ):
        # The target dtype is read from the last positional argument.
        return typing.cast(torch.dtype, q_node.args[-1])
    # We can't easily figure out dtype, return None
    return None
33+
34+
35+
class EthosU55DtypeSupport(OperatorSupportBase):
    """Rejects nodes whose dtype is outside the per-operator allow-lists below.

    A node whose dtype cannot be determined is accepted. Every rejection is
    reported through the supplied reporter.
    """

    def __init__(self, reporter: WhyNoPartitionReporter):
        super().__init__()
        self.reporter = reporter

    # Operators checked against int8/int16/int32.
    # NOTE(review): EthosU55TransposeCheck in this file matches on
    # permute_copy.default, while this list holds permute.default and
    # view.default. Confirm these are the targets that actually appear on the
    # nodes being checked.
    targeted_ops_i8_i16_i32 = [
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.repeat.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.view.default,
        exir_ops.edge.aten.permute.default,
    ]

    # Operators checked against int8 only.
    target_ops_i8 = tuple(TableOps.included_ops())

    def is_node_supported(  # noqa: C901
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """Return False (and report why) if `node` uses an unsupported dtype.

        Args:
            submodules: Unused; part of the OperatorSupportBase interface.
            node: The node to check.

        Returns:
            True if no dtype rule is violated or the dtype is unknown.
        """
        dtype = _try_determine_dtype(node)
        if dtype is None:
            # If we couldn't determine dtype, just return ok.
            return True

        if node.target in self.targeted_ops_i8_i16_i32:
            if dtype not in (torch.int8, torch.int16, torch.int32):
                self.reporter.report_reject(
                    node, f"Unsupported dtype {dtype} (Supports i8, i16, i32)."
                )
                return False

        if node.target in self.target_ops_i8:
            if dtype not in (torch.int8,):
                self.reporter.report_reject(
                    node, f"Unsupported dtype {dtype} (Supports i8)."
                )
                return False

        if node.target == exir_ops.edge.aten.convolution.default:
            conv_inputs = node.all_input_nodes
            ifm, weight = conv_inputs[0:2]
            # (label, operand node, allowed dtypes, text used in the message)
            operands = [
                ("input", ifm, (torch.int8, torch.int16), "i8, i16"),
                ("weight", weight, (torch.int8,), "i8"),
            ]
            if len(conv_inputs) > 2:
                operands.append(("bias", conv_inputs[2], (torch.int32,), "i32"))

            for label, operand, allowed, allowed_text in operands:
                operand_dtype = _try_determine_dtype(operand)
                if operand_dtype is not None and operand_dtype not in allowed:
                    # Report the operand's own dtype, not the dtype of the
                    # convolution node itself.
                    self.reporter.report_reject(
                        node,
                        f"Unsupported {label} dtype {operand_dtype} "
                        f"(Supports {allowed_text}).",
                    )
                    return False

        if node.target in (
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.bmm.default,
        ):
            for input_node in node.all_input_nodes:
                input_dtype = _try_determine_dtype(input_node)
                if input_dtype is not None and input_dtype != torch.int8:
                    # The rejection is reported on the offending input node.
                    self.reporter.report_reject(
                        input_node,
                        f"Input {input_node.name} has unsupported dtype {input_dtype} (Supports i8).",
                    )
                    return False

        return True
111+
112+
113+
class EthosU55NotSupported(OperatorSupportBase):
    """
    Certain operators are not supported on U55. These are listed in `unsupported_ops`.
    The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious.
    For unimplemented operators, this is the anticipated mapping, and it might be incorrect.
    """

    unsupported_ops = [
        exir_ops.edge.aten.any.default,  # REDUCE_ANY
        exir_ops.edge.aten.any.dim,  # REDUCE_ANY
        exir_ops.edge.aten.any.dims,  # REDUCE_ANY
        exir_ops.edge.aten.bitwise_and.Tensor,
        exir_ops.edge.aten.bitwise_or.Tensor,
        exir_ops.edge.aten.bitwise_xor.Tensor,
        # NOTE(review): no overload name here, unlike the other entries.
        # Confirm this compares equal to the node targets it should reject.
        exir_ops.edge.aten.bitwise_not,
        exir_ops.edge.aten.logical_and.default,
        exir_ops.edge.aten.logical_or.default,
        exir_ops.edge.aten.logical_xor.default,
        exir_ops.edge.aten.logical_not.default,
        exir_ops.edge.aten.amax.default,  # REDUCE_MAX
        exir_ops.edge.aten.amin.default,  # REDUCE_MIN
        exir_ops.edge.aten.eq.Tensor,
        exir_ops.edge.aten.eq.Scalar,
        exir_ops.edge.aten.ge.Tensor,
        exir_ops.edge.aten.gt.Tensor,
        exir_ops.edge.aten.le.Tensor,
        exir_ops.edge.aten.lt.Tensor,
        exir_ops.edge.aten.flip.default,  # REVERSE
        # NOTE(review): no overload name here either. Confirm as above.
        exir_ops.edge.aten.grid_sampler_2d,  # GATHER
        exir_ops.edge.aten.scatter.src,
        exir_ops.edge.aten.scatter.value,
        exir_ops.edge.aten.select_scatter.default,
        exir_ops.edge.aten.scatter_reduce.two,
        exir_ops.edge.aten.scatter_add.default,
        exir_ops.edge.aten.upsample_nearest2d.vec,  # RESIZE
        exir_ops.edge.aten.upsample_bilinear2d.vec,  # RESIZE
        exir_ops.edge.aten.reflection_pad1d.default,  # REVERSE
        exir_ops.edge.aten.reflection_pad2d.default,  # REVERSE
        exir_ops.edge.aten.reflection_pad3d.default,  # REVERSE
    ]

    def __init__(self, reporter: WhyNoPartitionReporter):
        # Initialise the base class like the other U55 checks in this file do.
        super().__init__()
        self.reporter = reporter

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """Return False (and report why) if `node.target` is in `unsupported_ops`.

        Args:
            submodules: Unused; part of the OperatorSupportBase interface.
            node: The node to check.
        """
        if node.target in self.unsupported_ops:
            self.reporter.report_reject(node, "Op is not supported on U55.")
            return False

        return True
166+
167+
168+
# Alias used in the annotations below for a shape (and, in one return type,
# a permutation) given as a list of ints.
shape_t = list[int]
169+
170+
171+
class EthosU55TransposeCheck(OperatorSupportBase):
172+
173+
def __init__(self, reporter: WhyNoPartitionReporter):
174+
super().__init__()
175+
self.reporter = reporter
176+
177+
def _pad_to_rank_4(
178+
self, shape: shape_t, permutation: list[int]
179+
) -> tuple[shape_t, shape_t]:
180+
diff = 4 - len(shape)
181+
padded_shape = [1] * diff + shape
182+
for i in range(len(permutation)):
183+
permutation[i] += diff
184+
padded_permutation = list(range(diff)) + permutation
185+
return padded_shape, padded_permutation
186+
187+
def axes_product(self, nhwc_shape: shape_t) -> int:
188+
product = 1
189+
for axes in nhwc_shape:
190+
product *= axes
191+
return product
192+
193+
def _permute_constraint_i8_i16(
194+
self, nhwc_shape: list[int], permutation: list[int]
195+
) -> bool:
196+
"""Returns True if the constraints are ok."""
197+
N, H, W, C = nhwc_shape
198+
match permutation:
199+
case (0, 1, 2, 3): # NHWC -> NHWC
200+
return True
201+
case (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2): # NHWC -> NWHC, NHCW, NCWH
202+
return N * H <= 65536 and W <= 65536 and C <= 65536
203+
case _:
204+
return self.axes_product(nhwc_shape) <= 65536
205+
206+
def _permute_constraint_i32(
207+
self, nhwc_shape: list[int], permutation: list[int]
208+
) -> bool:
209+
"""Returns True if the constraints are ok."""
210+
N, H, W, C = nhwc_shape
211+
match permutation:
212+
case (0, 1, 2, 3): # NHWC -> NHWC
213+
return C <= 32768
214+
case (0, 2, 1, 3): # NHWC -> NHWC
215+
return N == 1 and H <= 65536 and W <= 65536 and C <= 16384
216+
case (0, 1, 3, 2): # NHWC -> NHCW
217+
return N * H <= 65536 and W <= 65536 and C <= 65536
218+
case _:
219+
return False
220+
221+
def _permute_constraint(self, shape, permutation, dtype):
222+
if dtype in (torch.int8, torch.int16):
223+
return self._permute_constraint_i8_i16(shape, permutation)
224+
if dtype == torch.int32:
225+
return not self._permute_constraint_i32(shape, permutation)
226+
return True
227+
228+
def is_node_supported(
229+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
230+
) -> bool:
231+
232+
if not node.target == exir_ops.edge.aten.permute_copy.default:
233+
return True
234+
235+
shape = list(get_first_fake_tensor(node).shape)
236+
dtype = _try_determine_dtype(node)
237+
permutation = list(typing.cast(list[int], node.args[1]))
238+
239+
rank = len(shape)
240+
if rank > 4:
241+
if dtype == torch.int32:
242+
self.reporter.report_reject(
243+
node, f"No support for {permutation=} in int32."
244+
)
245+
return False
246+
if dtype in (torch.int8, torch.int16):
247+
if self.axes_product(shape) > 65536:
248+
self.reporter.report_reject(
249+
node,
250+
f"No support for {shape=}, {dtype=}. Product of axes must be <65536",
251+
)
252+
return False
253+
return True
254+
255+
shape, permutation = self._pad_to_rank_4(shape, permutation)
256+
if rank == 3 or rank == 4:
257+
# For rank 3 and 4, we can have channels first or channels last dim order.
258+
# Since we don't know which at partition-time, test both.
259+
260+
nhwc_shape = tosa_shape(shape, [0, 2, 3, 1])
261+
nhwc_permutation = transform_permutation_vector(permutation, [0, 2, 3, 1])
262+
263+
if not self._permute_constraint(nhwc_shape, nhwc_permutation, dtype):
264+
self.reporter.report_reject(
265+
node,
266+
f"Unsupported NHWC {nhwc_shape=} for {nhwc_permutation=}, {dtype=}",
267+
)
268+
return False
269+
270+
if not self._permute_constraint(shape, permutation, dtype):
271+
self.reporter.report_reject(
272+
node, f"Unsupported NCHW {shape=} for {permutation=}, {dtype=}"
273+
)
274+
return False
275+
276+
return True

backends/arm/operator_support/tosa_supported_operators.py

+7-55
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
1919
FuseQuantizedActivationPass,
2020
)
21+
from executorch.backends.arm.operator_support.ethos_u55_support import (
22+
EthosU55DtypeSupport,
23+
EthosU55NotSupported,
24+
EthosU55TransposeCheck,
25+
)
2126
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
2227
from executorch.exir import ExportedProgram
2328
from executorch.exir.backend.utils import WhyNoPartitionReporter
@@ -118,6 +123,8 @@ def tosa_support_factory(
118123
negative_checks.append(CheckProperQuantization(reporter))
119124
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
120125
negative_checks.append(EthosU55NotSupported(reporter))
126+
negative_checks.append(EthosU55DtypeSupport(reporter))
127+
negative_checks.append(EthosU55TransposeCheck(reporter))
121128

122129
return chain(
123130
reporter.wrap_check(
@@ -216,61 +223,6 @@ def is_node_supported(
216223
return supported
217224

218225

219-
class EthosU55NotSupported(OperatorSupportBase):
220-
"""
221-
Certain operators are not supported on U55. These are listed in `unsupported_ops`.
222-
The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious.
223-
For unimplemented operators, this is the anticipated mapping, and it might be incorrect.
224-
"""
225-
226-
unsupported_ops = [
227-
exir_ops.edge.aten.any.default, # REDUCE_ANY
228-
exir_ops.edge.aten.any.dim, # REDUCE_ANY
229-
exir_ops.edge.aten.any.dims, # REDUCE_ANY
230-
exir_ops.edge.aten.bitwise_and.Tensor,
231-
exir_ops.edge.aten.bitwise_or.Tensor,
232-
exir_ops.edge.aten.bitwise_xor.Tensor,
233-
exir_ops.edge.aten.bitwise_not,
234-
exir_ops.edge.aten.logical_and.default,
235-
exir_ops.edge.aten.logical_or.default,
236-
exir_ops.edge.aten.logical_xor.default,
237-
exir_ops.edge.aten.logical_not.default,
238-
exir_ops.edge.aten.amax.default, # REDUCE_MAX
239-
exir_ops.edge.aten.amin.default, # REDUCE_MIN
240-
exir_ops.edge.aten.eq.Tensor,
241-
exir_ops.edge.aten.eq.Scalar,
242-
exir_ops.edge.aten.ge.Tensor,
243-
exir_ops.edge.aten.gt.Tensor,
244-
exir_ops.edge.aten.le.Tensor,
245-
exir_ops.edge.aten.lt.Tensor,
246-
exir_ops.edge.aten.flip.default, # REVERSE
247-
exir_ops.edge.aten.grid_sampler_2d, # GATHER
248-
exir_ops.edge.aten.scatter.src,
249-
exir_ops.edge.aten.scatter.value,
250-
exir_ops.edge.aten.select_scatter.default,
251-
exir_ops.edge.aten.scatter_reduce.two,
252-
exir_ops.edge.aten.scatter_add.default,
253-
exir_ops.edge.aten.upsample_nearest2d.vec, # RESIZE
254-
exir_ops.edge.aten.upsample_bilinear2d.vec, # RESIZE
255-
exir_ops.edge.aten.reflection_pad1d.default, # REVERSE
256-
exir_ops.edge.aten.reflection_pad2d.default, # REVERSE
257-
exir_ops.edge.aten.reflection_pad3d.default, # REVERSE
258-
]
259-
260-
def __init__(self, reporter: WhyNoPartitionReporter):
261-
self.reporter = reporter
262-
263-
def is_node_supported(
264-
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
265-
) -> bool:
266-
267-
if node.target in self.unsupported_ops:
268-
self.reporter.report_reject(node, "Op is not supported on U55.")
269-
return False
270-
271-
return True
272-
273-
274226
class NeedsDecompositionCheck(OperatorSupportBase):
275227
"""
276228
Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding

0 commit comments

Comments
 (0)