Skip to content

Commit ffe9181

Browse files
committed
Arm backend: Add Ethos-U55 permute check
Signed-off-by: Erik Lundell <[email protected]> Change-Id: Id7c6d6469e96e4133b7b1a54be6ea66bc7dc861a
1 parent 7889c0f commit ffe9181

File tree

4 files changed

+185
-49
lines changed

4 files changed

+185
-49
lines changed

backends/arm/operator_support/ethos_u55_support.py

+131-31
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,27 @@
1111
import torch.fx as fx
1212
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1313
from executorch.backends.arm._passes.insert_table_ops import TableOps
14+
from executorch.backends.arm.operators.op_permute import transform_permutation_vector
15+
from executorch.backends.arm.tosa_utils import tosa_shape
1416
from executorch.exir.backend.utils import WhyNoPartitionReporter
1517

1618
from executorch.exir.dialects._ops import ops as exir_ops
1719
from torch.fx.passes.operator_support import OperatorSupportBase
1820

1921

22+
def _try_determine_dtype(node: fx.Node) -> torch.dtype | None:
    """Attempt to figure out the quantized data type of node.

    Resolution order:
      1. If the node's own fake tensor is not floating point, use its dtype.
      2. If the node is a dequantize_per_tensor op, use the dtype of its
         first input node.
      3. If the node's first user is a quantize_per_tensor op, use the last
         positional argument of that op.

    Returns:
        The resolved dtype, or None when none of the rules above apply
        (including when the node has no users to inspect).
    """
    dtype = get_first_fake_tensor(node).dtype
    if not dtype.is_floating_point:
        return dtype

    if node.target is exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default:
        return get_first_fake_tensor(node.all_input_nodes[0]).dtype

    # A node without users has no following quantize op to look at; indexing
    # into an empty user list would raise IndexError, so bail out instead.
    q_node = next(iter(node.users), None)
    if q_node is None:
        return None

    if q_node.target is exir_ops.edge.quantized_decomposed.quantize_per_tensor.default:
        return typing.cast(torch.dtype, q_node.args[-1])

    # We can't easily figure out dtype, return None
    return None
33+
34+
2035
class EthosU55DtypeSupport(OperatorSupportBase):
2136

2237
def __init__(self, reporter: WhyNoPartitionReporter):
@@ -33,37 +48,11 @@ def __init__(self, reporter: WhyNoPartitionReporter):
3348

3449
target_ops_i8 = tuple(TableOps.included_ops())
3550

36-
def _try_determine_dtype(self, node: fx.Node) -> torch.dtype | None:
37-
"""Attempt to figure out the quantized data type of node. On failure, return None."""
38-
39-
dtype = get_first_fake_tensor(node).dtype
40-
if not dtype.is_floating_point:
41-
return dtype
42-
43-
if (
44-
node.target
45-
is exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
46-
):
47-
return get_first_fake_tensor(node.all_input_nodes[0]).dtype
48-
49-
if len(node.users) == 0:
50-
return None
51-
52-
q_node = list(node.users)[0]
53-
if (
54-
q_node.target
55-
is exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
56-
):
57-
return typing.cast(torch.dtype, q_node.args[-1])
58-
59-
# We can't easily figure out dtype, return None
60-
return None
61-
6251
def is_node_supported( # noqa: C901
6352
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
6453
) -> bool:
6554

66-
dtype = self._try_determine_dtype(node)
55+
dtype = _try_determine_dtype(node)
6756
if dtype is None:
6857
# If we couldn't determine dtype, just return ok.
6958
return True
@@ -84,21 +73,21 @@ def is_node_supported( # noqa: C901
8473

8574
if node.target == exir_ops.edge.aten.convolution.default:
8675
ifm, weight = node.all_input_nodes[0:2]
87-
ifm_dtype = self._try_determine_dtype(ifm)
76+
ifm_dtype = _try_determine_dtype(ifm)
8877
if ifm_dtype is not None and ifm_dtype not in (torch.int8, torch.int16):
8978
self.reporter.report_reject(
9079
node, f"Unsupported input dtype {dtype} (Supports i8, i16)."
9180
)
9281
return False
93-
weight_dtype = self._try_determine_dtype(weight)
82+
weight_dtype = _try_determine_dtype(weight)
9483
if weight_dtype is not None and weight_dtype not in (torch.int8,):
9584
self.reporter.report_reject(
9685
node, f"Unsupported weight dtype {dtype} (Supports i8)."
9786
)
9887
return False
9988
if len(node.all_input_nodes) > 2:
10089
bias = node.all_input_nodes[2]
101-
bias_dtype = self._try_determine_dtype(bias)
90+
bias_dtype = _try_determine_dtype(bias)
10291
if bias_dtype is not None and bias_dtype not in (torch.int32,):
10392
self.reporter.report_reject(
10493
node, f"Unsupported bias dtype {dtype} (Supports i32)."
@@ -110,7 +99,7 @@ def is_node_supported( # noqa: C901
11099
exir_ops.edge.aten.bmm.default,
111100
):
112101
for input_node in node.all_input_nodes:
113-
dtype = self._try_determine_dtype(input_node)
102+
dtype = _try_determine_dtype(input_node)
114103
if dtype is not None and dtype != torch.int8:
115104
self.reporter.report_reject(
116105
input_node,
@@ -174,3 +163,114 @@ def is_node_supported(
174163
return False
175164

176165
return True
166+
167+
168+
shape_t = list[int]
169+
170+
171+
class EthosU55TransposeCheck(OperatorSupportBase):
172+
173+
def __init__(self, reporter: WhyNoPartitionReporter):
174+
super().__init__()
175+
self.reporter = reporter
176+
177+
def _pad_to_rank_4(
178+
self, shape: shape_t, permutation: list[int]
179+
) -> tuple[shape_t, shape_t]:
180+
diff = 4 - len(shape)
181+
padded_shape = [1] * diff + shape
182+
for i in range(len(permutation)):
183+
permutation[i] += diff
184+
padded_permutation = list(range(diff)) + permutation
185+
return padded_shape, padded_permutation
186+
187+
def axes_product(self, nhwc_shape: shape_t) -> int:
188+
product = 1
189+
for axes in nhwc_shape:
190+
product *= axes
191+
return product
192+
193+
def _permute_constraint_i8_i16(
194+
self, nhwc_shape: list[int], permutation: list[int]
195+
) -> bool:
196+
"""Returns True if the constraints are ok."""
197+
N, H, W, C = nhwc_shape
198+
match permutation:
199+
case (0, 1, 2, 3): # NHWC -> NHWC
200+
return True
201+
case (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2): # NHWC -> NWHC, NHCW, NCWH
202+
return N * H <= 65536 and W <= 65536 and C <= 65536
203+
case _:
204+
return self.axes_product(nhwc_shape) <= 65536
205+
206+
def _permute_constraint_i32(
207+
self, nhwc_shape: list[int], permutation: list[int]
208+
) -> bool:
209+
"""Returns True if the constraints are ok."""
210+
N, H, W, C = nhwc_shape
211+
match permutation:
212+
case (0, 1, 2, 3): # NHWC -> NHWC
213+
return C <= 32768
214+
case (0, 2, 1, 3): # NHWC -> NHWC
215+
return N == 1 and H <= 65536 and W <= 65536 and C <= 16384
216+
case (0, 1, 3, 2): # NHWC -> NHCW
217+
return N * H <= 65536 and W <= 65536 and C <= 65536
218+
case _:
219+
return False
220+
221+
def _permute_constraint(self, shape, permutation, dtype):
222+
if dtype in (torch.int8, torch.int16):
223+
return self._permute_constraint_i8_i16(shape, permutation)
224+
if dtype == torch.int32:
225+
return not self._permute_constraint_i32(shape, permutation)
226+
return True
227+
228+
def is_node_supported(
229+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
230+
) -> bool:
231+
232+
if not node.target == exir_ops.edge.aten.permute_copy.default:
233+
return True
234+
235+
shape = list(get_first_fake_tensor(node).shape)
236+
dtype = _try_determine_dtype(node)
237+
permutation = list(typing.cast(list[int], node.args[1]))
238+
239+
rank = len(shape)
240+
if rank > 4:
241+
if dtype == torch.int32:
242+
self.reporter.report_reject(
243+
node, f"No support for {permutation=} in int32."
244+
)
245+
return False
246+
if dtype in (torch.int8, torch.int16):
247+
if self.axes_product(shape) > 65536:
248+
self.reporter.report_reject(
249+
node,
250+
f"No support for {shape=}, {dtype=}. Product of axes must be <65536",
251+
)
252+
return False
253+
return True
254+
255+
shape, permutation = self._pad_to_rank_4(shape, permutation)
256+
if rank == 3 or rank == 4:
257+
# For rank 3 and 4, we can have channels first or channels last dim order.
258+
# Since we don't know which at partition-time, test both.
259+
260+
nhwc_shape = tosa_shape(shape, [0, 2, 3, 1])
261+
nhwc_permutation = transform_permutation_vector(permutation, [0, 2, 3, 1])
262+
263+
if not self._permute_constraint(nhwc_shape, nhwc_permutation, dtype):
264+
self.reporter.report_reject(
265+
node,
266+
f"Unsupported NHWC {nhwc_shape=} for {nhwc_permutation=}, {dtype=}",
267+
)
268+
return False
269+
270+
if not self._permute_constraint(shape, permutation, dtype):
271+
self.reporter.report_reject(
272+
node, f"Unsupported NCHW {shape=} for {permutation=}, {dtype=}"
273+
)
274+
return False
275+
276+
return True

backends/arm/operator_support/tosa_supported_operators.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from executorch.backends.arm.operator_support.ethos_u55_support import (
2222
EthosU55DtypeSupport,
2323
EthosU55NotSupported,
24+
EthosU55TransposeCheck,
2425
)
2526
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
2627
from executorch.exir import ExportedProgram
@@ -123,6 +124,7 @@ def tosa_support_factory(
123124
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
124125
negative_checks.append(EthosU55NotSupported(reporter))
125126
negative_checks.append(EthosU55DtypeSupport(reporter))
127+
negative_checks.append(EthosU55TransposeCheck(reporter))
126128

127129
return chain(
128130
reporter.wrap_check(

backends/arm/operators/op_permute.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,29 @@ def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]:
6565
return p
6666

6767

68+
def transform_permutation_vector(permutation_vector: list[int], dim_order: list[int]):
    """Re-express a permutation vector relative to the given dim_order.

    dim_order is itself a permutation (call it S) and permutation_vector
    describes the permutation P. The result is the vector form of
    S * P * S^-1: map out of dim_order, apply P, then map back into
    dim_order.
    """
    # Work with permutation matrices so that chaining is a matrix product
    # and inversion is a transpose.
    dim_order_matrix = permutation_vector_to_matrix(dim_order)
    dim_order_inverse = dim_order_matrix.t()
    permute_matrix = permutation_vector_to_matrix(permutation_vector)

    # A product of permutation matrices is again a permutation matrix, so it
    # can be converted straight back into a permutation vector.
    conjugated = dim_order_matrix @ (permute_matrix @ dim_order_inverse)
    return permutation_matrix_to_vector(conjugated)
89+
90+
6891
@register_node_visitor
6992
class PermuteVisitor(NodeVisitor):
7093
target = "aten.permute_copy.default"
@@ -86,23 +109,10 @@ def define_node(
86109

87110
if output.dim_order != tuple(range(len(output.dim_order))):
88111
# the permutation vector can't be used directly if we are not in NCHW dim_order.
89-
# We need to first transform to NCHW, apply P,
90-
# and then transform back to the original dim_order.
91-
# This transformation, S, is also a permutation, with the dim_order as permutation vector.
92-
93-
# To do this, represent P and S with permutation matrices.
94-
# Matrices can handle chained transformations and inversion easily.
95-
S = permutation_vector_to_matrix(output.dim_order)
96-
# The inverse of a permutation matrix is its transpose.
97-
S_inverse = S.transpose(1, 0)
98-
P = permutation_vector_to_matrix(permutation_vector)
99-
100-
# The complete transformation is S * P * S_inverse.
101-
transformation_matrix = S.matmul(P.matmul(S_inverse))
102-
103-
# Luckily, since it is just a combination of permutations, the result is also a permutation
104-
# that can again be described by a new permutation vector.
105-
permutation_vector = permutation_matrix_to_vector(transformation_matrix)
112+
# Transform to dim_order.
113+
permutation_vector = transform_permutation_vector(
114+
permutation_vector, output.dim_order
115+
)
106116

107117
attr = ts.TosaSerializerAttribute()
108118
attr.TransposeAttribute(permutation_vector)

backends/arm/test/ops/test_permute.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -20,6 +20,7 @@
2020
)
2121
from executorch.backends.arm.test import common, conftest
2222
from executorch.backends.arm.test.tester.arm_tester import ArmTester
23+
from executorch.backends.arm.test.tester.test_pipeline import OpNotSupportedPipeline
2324
from executorch.backends.arm.tosa_specification import TosaSpecification
2425
from executorch.backends.xnnpack.test.tester.tester import Quantize
2526
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -163,3 +164,26 @@ def test_permute_u85_BI_xfails(
163164
self._test_permute_ethos_BI_pipeline(
164165
self.Permute(dims=dims), common.get_u85_compile_spec(), (test_data,)
165166
)
167+
168+
169+
reject_data_suite = {
170+
"int8_r3_axes_product": ([1, 700, 1000], [2, 1, 0], torch.int8),
171+
"int8_r5_axes_product": ([1, 1, 1, 700, 1000], [0, 1, 2, 3, 4], torch.int8),
172+
"int8_r4_NH_too_large": ([700, 100, 1, 1], [0, 1, 3, 2], torch.int8),
173+
"int32_r5_no_support": ([2, 2, 2, 2, 2], [3, 4, 2, 1, 0], torch.int32),
174+
}
175+
input_t = tuple[torch.Tensor]
176+
177+
178+
@common.parametrize("test_data", reject_data_suite)
def test_permute_u55_BI_not_delegated(test_data):
    """Tests that we don't delegate these ops since they are not supported on U55."""
    input_shape, dims, dtype = test_data

    # Random values scaled to [0, 10) before casting to the case's dtype.
    scaled = torch.rand(input_shape) * 10
    inputs = (scaled.to(dtype),)

    module = TestPermute.Permute(dims=dims)
    # Presumably the op counts expected to stay outside the delegate;
    # verify against OpNotSupportedPipeline.
    expected_ops = {"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1}

    OpNotSupportedPipeline[input_t](
        module,
        inputs,
        "TOSA-0.80+BI+u55",
        expected_ops,
    ).run()

0 commit comments

Comments
 (0)