
Arm backend: Add additional unsupported checks to Ethos-U55 backend #9796


Merged: 2 commits, Apr 1, 2025
1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
@@ -7,6 +7,7 @@

from . import ( # noqa
convolution_support,
ethos_u55_support,
minmax_support,
pool_2d_support,
reduce_sum_support,
276 changes: 276 additions & 0 deletions backends/arm/operator_support/ethos_u55_support.py
@@ -0,0 +1,276 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import typing

import torch
import torch.fx as fx
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.insert_table_ops import TableOps
from executorch.backends.arm.operators.op_permute import transform_permutation_vector
from executorch.backends.arm.tosa_utils import tosa_shape
from executorch.exir.backend.utils import WhyNoPartitionReporter

from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.operator_support import OperatorSupportBase


def _try_determine_dtype(node: fx.Node) -> torch.dtype | None:
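    """Attempt to determine the dtype a node operates on in a quantized graph:
    non-float fake-tensor dtypes are returned as-is, a dq node reports the dtype
    of its input, and other float nodes report the dtype of a following q node.
    Returns None if the dtype cannot be determined this way."""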
dtype = get_first_fake_tensor(node).dtype
if not dtype.is_floating_point:
return dtype
if node.target is exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default:
return get_first_fake_tensor(node.all_input_nodes[0]).dtype
q_node = list(node.users)[0]
if q_node.target is exir_ops.edge.quantized_decomposed.quantize_per_tensor.default:
return typing.cast(torch.dtype, q_node.args[-1])
# We can't easily figure out dtype, return None
return None


class EthosU55DtypeSupport(OperatorSupportBase):
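    """Rejects nodes whose operator has dtype restrictions on U55 when the
    inferred dtype falls outside the supported set."""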

def __init__(self, reporter: WhyNoPartitionReporter):
super().__init__()
self.reporter = reporter

targeted_ops_i8_i16_i32 = [
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.view.default,
exir_ops.edge.aten.permute.default,
]

target_ops_i8 = tuple(TableOps.included_ops())
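    # TableOps are lowered via a TOSA TABLE operator, hence the int8-only check below.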

def is_node_supported( # noqa: C901
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

dtype = _try_determine_dtype(node)
if dtype is None:
# If we couldn't determine dtype, just return ok.
return True

if node.target in self.targeted_ops_i8_i16_i32:
if dtype not in (torch.int8, torch.int16, torch.int32):
self.reporter.report_reject(
node, f"Unsupported dtype {dtype} (Supports i8, i16, i32)."
)
return False

if node.target in self.target_ops_i8:
if dtype not in (torch.int8,):
self.reporter.report_reject(
node, f"Unsupported dtype {dtype} (Supports i8)."
)
return False

if node.target == exir_ops.edge.aten.convolution.default:
ifm, weight = node.all_input_nodes[0:2]
ifm_dtype = _try_determine_dtype(ifm)
if ifm_dtype is not None and ifm_dtype not in (torch.int8, torch.int16):
self.reporter.report_reject(
node, f"Unsupported input dtype {dtype} (Supports i8, i16)."
)
return False
weight_dtype = _try_determine_dtype(weight)
if weight_dtype is not None and weight_dtype not in (torch.int8,):
self.reporter.report_reject(
node, f"Unsupported weight dtype {dtype} (Supports i8)."
)
return False
if len(node.all_input_nodes) > 2:
bias = node.all_input_nodes[2]
bias_dtype = _try_determine_dtype(bias)
if bias_dtype is not None and bias_dtype not in (torch.int32,):
self.reporter.report_reject(
node, f"Unsupported bias dtype {dtype} (Supports i32)."
)
return False

if node.target in (
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.bmm.default,
):
for input_node in node.all_input_nodes:
dtype = _try_determine_dtype(input_node)
if dtype is not None and dtype != torch.int8:
self.reporter.report_reject(
input_node,
f"Input {input_node.name} has unsupported dtype {dtype} (Supports i8).",
)
return False

return True


class EthosU55NotSupported(OperatorSupportBase):
"""
Certain operators are not supported on U55. These are listed in `unsupported_ops`.
    Where the mapping is not obvious, a comment names the unsupported TOSA operator
    that the aten operator maps to. For operators that are not yet implemented, this
    is the anticipated mapping and may be incorrect.
"""

unsupported_ops = [
exir_ops.edge.aten.any.default, # REDUCE_ANY
exir_ops.edge.aten.any.dim, # REDUCE_ANY
exir_ops.edge.aten.any.dims, # REDUCE_ANY
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.bitwise_not,
exir_ops.edge.aten.logical_and.default,
exir_ops.edge.aten.logical_or.default,
exir_ops.edge.aten.logical_xor.default,
exir_ops.edge.aten.logical_not.default,
exir_ops.edge.aten.amax.default, # REDUCE_MAX
exir_ops.edge.aten.amin.default, # REDUCE_MIN
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.eq.Scalar,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.flip.default, # REVERSE
exir_ops.edge.aten.grid_sampler_2d, # GATHER
exir_ops.edge.aten.scatter.src,
exir_ops.edge.aten.scatter.value,
exir_ops.edge.aten.select_scatter.default,
exir_ops.edge.aten.scatter_reduce.two,
exir_ops.edge.aten.scatter_add.default,
exir_ops.edge.aten.upsample_nearest2d.vec, # RESIZE
exir_ops.edge.aten.upsample_bilinear2d.vec, # RESIZE
exir_ops.edge.aten.reflection_pad1d.default, # REVERSE
exir_ops.edge.aten.reflection_pad2d.default, # REVERSE
exir_ops.edge.aten.reflection_pad3d.default, # REVERSE
]

def __init__(self, reporter: WhyNoPartitionReporter):
self.reporter = reporter

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

if node.target in self.unsupported_ops:
self.reporter.report_reject(node, "Op is not supported on U55.")
return False

return True


shape_t = list[int]


class EthosU55TransposeCheck(OperatorSupportBase):
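    """Rejects permutes that violate the U55 shape and dtype constraints on
    the TOSA TRANSPOSE operator."""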

def __init__(self, reporter: WhyNoPartitionReporter):
super().__init__()
self.reporter = reporter

def _pad_to_rank_4(
self, shape: shape_t, permutation: list[int]
) -> tuple[shape_t, shape_t]:
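        """Pad shape and permutation to rank 4 by prepending ones to the shape
        and offsetting the permutation indices accordingly."""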
diff = 4 - len(shape)
padded_shape = [1] * diff + shape
for i in range(len(permutation)):
permutation[i] += diff
padded_permutation = list(range(diff)) + permutation
return padded_shape, padded_permutation

def axes_product(self, nhwc_shape: shape_t) -> int:
product = 1
        for axis in nhwc_shape:
            product *= axis
return product

def _permute_constraint_i8_i16(
self, nhwc_shape: list[int], permutation: list[int]
) -> bool:
"""Returns True if the constraints are ok."""
N, H, W, C = nhwc_shape
match permutation:
case (0, 1, 2, 3): # NHWC -> NHWC
return True
            case (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2):  # NHWC -> NWHC, NHCW, NCHW
return N * H <= 65536 and W <= 65536 and C <= 65536
case _:
return self.axes_product(nhwc_shape) <= 65536

def _permute_constraint_i32(
self, nhwc_shape: list[int], permutation: list[int]
) -> bool:
"""Returns True if the constraints are ok."""
N, H, W, C = nhwc_shape
match permutation:
case (0, 1, 2, 3): # NHWC -> NHWC
return C <= 32768
            case (0, 2, 1, 3):  # NHWC -> NWHC
return N == 1 and H <= 65536 and W <= 65536 and C <= 16384
case (0, 1, 3, 2): # NHWC -> NHCW
return N * H <= 65536 and W <= 65536 and C <= 65536
case _:
return False
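    # Worked example (illustrative, not part of the check): an int8 permute of
    # NHWC shape (1, 64, 64, 32) with permutation (0, 2, 1, 3) is accepted since
    # N*H = 64, W = 64 and C = 32 are all <= 65536. In int32, the same permute
    # additionally requires N == 1 and C <= 16384, which also holds here.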

def _permute_constraint(self, shape, permutation, dtype):
if dtype in (torch.int8, torch.int16):
return self._permute_constraint_i8_i16(shape, permutation)
if dtype == torch.int32:
            return self._permute_constraint_i32(shape, permutation)
return True

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

        if node.target != exir_ops.edge.aten.permute_copy.default:
return True

shape = list(get_first_fake_tensor(node).shape)
dtype = _try_determine_dtype(node)
permutation = list(typing.cast(list[int], node.args[1]))

rank = len(shape)
if rank > 4:
if dtype == torch.int32:
self.reporter.report_reject(
node, f"No support for {permutation=} in int32."
)
return False
if dtype in (torch.int8, torch.int16):
if self.axes_product(shape) > 65536:
self.reporter.report_reject(
node,
f"No support for {shape=}, {dtype=}. Product of axes must be <65536",
)
return False
return True

shape, permutation = self._pad_to_rank_4(shape, permutation)
if rank == 3 or rank == 4:
# For rank 3 and 4, we can have channels first or channels last dim order.
# Since we don't know which at partition-time, test both.

nhwc_shape = tosa_shape(shape, [0, 2, 3, 1])
nhwc_permutation = transform_permutation_vector(permutation, [0, 2, 3, 1])

if not self._permute_constraint(nhwc_shape, nhwc_permutation, dtype):
self.reporter.report_reject(
node,
f"Unsupported NHWC {nhwc_shape=} for {nhwc_permutation=}, {dtype=}",
)
return False

if not self._permute_constraint(shape, permutation, dtype):
self.reporter.report_reject(
node, f"Unsupported NCHW {shape=} for {permutation=}, {dtype=}"
)
return False

return True
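A minimal sketch (not part of this patch) of how one of these checks can be run
directly against an edge-dialect graph; `edge_program` here is an assumed,
pre-existing `to_edge()` result:

from executorch.backends.arm.operator_support.ethos_u55_support import (
    EthosU55NotSupported,
)
from executorch.exir.backend.utils import WhyNoPartitionReporter

# Ask the U55 op-support check about every call_function node in the graph.
reporter = WhyNoPartitionReporter()
check = EthosU55NotSupported(reporter)
graph_module = edge_program.exported_program().graph_module
for node in graph_module.graph.nodes:
    if node.op == "call_function" and not check.is_node_supported({}, node):
        print(f"{node.name} would be rejected for Ethos-U55")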
62 changes: 7 additions & 55 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -18,6 +18,11 @@
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
FuseQuantizedActivationPass,
)
from executorch.backends.arm.operator_support.ethos_u55_support import (
EthosU55DtypeSupport,
EthosU55NotSupported,
EthosU55TransposeCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.exir import ExportedProgram
from executorch.exir.backend.utils import WhyNoPartitionReporter
@@ -118,6 +123,8 @@ def tosa_support_factory(
negative_checks.append(CheckProperQuantization(reporter))
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
negative_checks.append(EthosU55NotSupported(reporter))
negative_checks.append(EthosU55DtypeSupport(reporter))
negative_checks.append(EthosU55TransposeCheck(reporter))

return chain(
reporter.wrap_check(
@@ -216,61 +223,6 @@ def is_node_supported(
return supported


class EthosU55NotSupported(OperatorSupportBase):
"""
Certain operators are not supported on U55. These are listed in `unsupported_ops`.
The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious.
For unimplemented operators, this is the anticipated mapping, and it might be incorrect.
"""

unsupported_ops = [
exir_ops.edge.aten.any.default, # REDUCE_ANY
exir_ops.edge.aten.any.dim, # REDUCE_ANY
exir_ops.edge.aten.any.dims, # REDUCE_ANY
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.bitwise_not,
exir_ops.edge.aten.logical_and.default,
exir_ops.edge.aten.logical_or.default,
exir_ops.edge.aten.logical_xor.default,
exir_ops.edge.aten.logical_not.default,
exir_ops.edge.aten.amax.default, # REDUCE_MAX
exir_ops.edge.aten.amin.default, # REDUCE_MIN
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.eq.Scalar,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.flip.default, # REVERSE
exir_ops.edge.aten.grid_sampler_2d, # GATHER
exir_ops.edge.aten.scatter.src,
exir_ops.edge.aten.scatter.value,
exir_ops.edge.aten.select_scatter.default,
exir_ops.edge.aten.scatter_reduce.two,
exir_ops.edge.aten.scatter_add.default,
exir_ops.edge.aten.upsample_nearest2d.vec, # RESIZE
exir_ops.edge.aten.upsample_bilinear2d.vec, # RESIZE
exir_ops.edge.aten.reflection_pad1d.default, # REVERSE
exir_ops.edge.aten.reflection_pad2d.default, # REVERSE
exir_ops.edge.aten.reflection_pad3d.default, # REVERSE
]

def __init__(self, reporter: WhyNoPartitionReporter):
self.reporter = reporter

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

if node.target in self.unsupported_ops:
self.reporter.report_reject(node, "Op is not supported on U55.")
return False

return True


class NeedsDecompositionCheck(OperatorSupportBase):
"""
Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding