Skip to content

Commit 242995d

Browse files
committed
Update
[ghstack-poisoned]
2 parents 3f1b775 + afad88e commit 242995d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+617
-295
lines changed

CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -161,7 +161,7 @@ if(OPTIMIZE_SIZE)
161161
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Os")
162162
else()
163163
# -O2: Moderate opt.
164-
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O2")
164+
set(CMAKE_CXX_FLAGS_RELEASE "-O2 ${CMAKE_CXX_FLAGS_RELEASE}")
165165
endif()
166166

167167
option(EXECUTORCH_BUILD_ANDROID_JNI "Build Android JNI" OFF)

backends/apple/coreml/TARGETS

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ runtime.cxx_python_extension(
7272
headers = glob([
7373
"runtime/inmemoryfs/**/*.hpp",
7474
]),
75-
base_module = "",
75+
base_module = "executorch.backends.apple.coreml",
7676
compiler_flags = [
7777
"-std=c++17",
7878
],

backends/arm/_passes/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from . import arm_pass_utils # noqa
88
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
99
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
10-
from .cast_int64_pass import CastInt64ToInt32Pass # noqa
10+
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
11+
from .cast_to_int32_pass import CastToInt32Pass # noqa
1112
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
1213
from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass # noqa
1314
from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass # noqa

backends/arm/_passes/arm_pass_manager.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from executorch.backends.arm._passes import (
1111
AnnotateChannelsLastDimOrder,
1212
AnnotateDecomposedMatmulPass,
13-
CastInt64ToInt32Pass,
13+
CastInt64BuffersToInt32Pass,
14+
CastToInt32Pass,
1415
ComputeConstantOpsAOT,
1516
Conv1dUnsqueezePass,
1617
ConvertAnyDefaultDimDimsPass,
@@ -80,6 +81,8 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8081
self.add_pass(ConvertToClampPass())
8182
self.add_pass(ConvertMinMaxPass())
8283
self.add_pass(ConvertAnyDefaultDimDimsPass())
84+
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
85+
self.add_pass(CastToInt32Pass())
8386

8487
self.add_pass(ReplaceScalarWithTensorArgPass())
8588
self.add_pass(AnnotateDecomposedMatmulPass())
@@ -94,7 +97,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
9497
self.add_pass(SizeAdjustConv2DPass())
9598
self.add_pass(ConvertExpandCopyToRepeatPass())
9699
self.add_pass(UnsqueezeBeforeRepeatPass())
97-
self.add_pass(CastInt64ToInt32Pass(exported_program))
100+
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
98101
self.add_pass(KeepDimsFalseToSqueezePass())
99102
self.add_pass(Conv1dUnsqueezePass(exported_program))
100103
self.add_pass(DecomposeSelectPass())
@@ -141,7 +144,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
141144
self.add_pass(SizeAdjustConv2DPass())
142145
self.add_pass(ConvertExpandCopyToRepeatPass())
143146
self.add_pass(UnsqueezeBeforeRepeatPass())
144-
self.add_pass(CastInt64ToInt32Pass(exported_program))
147+
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
145148
self.add_pass(KeepDimsFalseToSqueezePass())
146149
self.add_pass(Conv1dUnsqueezePass(exported_program))
147150
self.add_pass(DecomposeSelectPass())

backends/arm/_passes/cast_int64_pass.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
logger.setLevel(logging.WARNING)
1616

1717

18-
class CastInt64ToInt32Pass(ExportPass):
18+
class CastInt64BuffersToInt32Pass(ExportPass):
1919
"""
2020
Cast int64 buffers to int32 if the int64 data is in int32 range.
2121
"""
2222

2323
def __init__(self, exported_program: torch.export.ExportedProgram):
24-
super(CastInt64ToInt32Pass, self).__init__()
24+
super(CastInt64BuffersToInt32Pass, self).__init__()
2525
self.exported_program = exported_program
2626

2727
def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):
+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
12+
class CastToInt32Pass(ExportPass):
13+
"""Casts the input to int32 if it is not already and casts back the output to the original input dtype."""
14+
15+
targeted_ops = {
16+
exir_ops.edge.aten.bitwise_left_shift.Tensor,
17+
exir_ops.edge.aten.bitwise_right_shift.Tensor,
18+
}
19+
20+
def call_operator(self, op, args, kwargs, meta):
21+
if op not in self.targeted_ops:
22+
return super().call_operator(op, args, kwargs, meta)
23+
24+
new_args: list = []
25+
did_cast = False
26+
for arg in args:
27+
if arg.data.dtype != torch.int32:
28+
new_args.append(
29+
super().call_operator(
30+
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
31+
(arg,),
32+
{"dtype": torch.int32},
33+
meta,
34+
)
35+
)
36+
did_cast = True
37+
else:
38+
new_args.append(arg)
39+
40+
output = super().call_operator(
41+
op,
42+
tuple(new_args),
43+
{},
44+
meta,
45+
)
46+
47+
if did_cast:
48+
output = super().call_operator(
49+
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
50+
(output,),
51+
{"dtype": args[0].data.dtype},
52+
meta,
53+
)
54+
return output

backends/arm/_passes/match_arg_ranks_pass.py

+3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ def __init__(self, exported_program):
4545
exir_ops.edge.aten.sub.Tensor,
4646
exir_ops.edge.aten.mul.Tensor,
4747
exir_ops.edge.aten.div.Tensor,
48+
exir_ops.edge.aten.bitwise_right_shift.Tensor,
49+
exir_ops.edge.aten.bitwise_left_shift.Tensor,
50+
exir_ops.edge.aten.eq.Tensor,
4851
]
4952

5053
def _match_op_rank(self, graph_module, node, arg, max_rank):

backends/arm/operator_support/right_shift_support.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222

2323
@register_tosa_support_check
2424
class RightShiftSupported(SupportedTOSAOperatorCheck):
25-
targets = [exir_ops.edge.aten.__rshift__.Scalar]
25+
targets = [
26+
exir_ops.edge.aten.bitwise_right_shift.Tensor,
27+
exir_ops.edge.aten.__rshift__.Scalar,
28+
]
2629

2730
tosa_specs = [
2831
TosaSpecification.create_from_string("TOSA-0.80+BI"),

backends/arm/operator_support/tosa_supported_operators.py

+4
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def is_node_supported(
158158
exir_ops.edge.aten.hardswish.default,
159159
exir_ops.edge.aten.div.Tensor,
160160
exir_ops.edge.aten.eq.Tensor,
161+
exir_ops.edge.aten.eq.Scalar,
161162
exir_ops.edge.aten.exp.default,
162163
exir_ops.edge.aten.log.default,
163164
exir_ops.edge.aten.linear.default,
@@ -205,6 +206,8 @@ def is_node_supported(
205206
exir_ops.edge.aten.amin.default,
206207
exir_ops.edge.aten.eye.default,
207208
exir_ops.edge.aten.linspace.default,
209+
exir_ops.edge.aten.bitwise_left_shift.Tensor,
210+
exir_ops.edge.aten.__lshift__.Scalar,
208211
torch.ops.aten.scalar_tensor.default,
209212
]
210213

@@ -233,6 +236,7 @@ class EthosU55NotSupported(OperatorSupportBase):
233236
exir_ops.edge.aten.amax.default, # REDUCE_MAX
234237
exir_ops.edge.aten.amin.default, # REDUCE_MIN
235238
exir_ops.edge.aten.eq.Tensor,
239+
exir_ops.edge.aten.eq.Scalar,
236240
exir_ops.edge.aten.ge.Tensor,
237241
exir_ops.edge.aten.gt.Tensor,
238242
exir_ops.edge.aten.le.Tensor,

backends/arm/operators/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
op_reciprocal,
3636
op_repeat,
3737
op_rescale,
38-
op_rshift,
38+
op_rshift_tensor,
3939
op_rsqrt,
4040
op_sigmoid,
4141
op_slice,

backends/arm/operators/op_rshift.py

-100
This file was deleted.
+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import serializer.tosa_serializer as ts # type: ignore
11+
import torch
12+
from executorch.backends.arm.operators.node_visitor import (
13+
NodeVisitor,
14+
register_node_visitor,
15+
)
16+
from executorch.backends.arm.tosa_mapping import TosaArg
17+
from executorch.backends.arm.tosa_specification import Tosa_0_80
18+
from serializer.tosa_serializer import TosaOp
19+
20+
21+
@register_node_visitor
22+
class RshiftVisitor(NodeVisitor):
23+
target = "aten.bitwise_right_shift.Tensor"
24+
25+
def define_node(
26+
self,
27+
node: torch.fx.Node,
28+
tosa_graph: ts.TosaSerializer,
29+
inputs: List[TosaArg],
30+
output: TosaArg,
31+
) -> None:
32+
33+
attr = ts.TosaSerializerAttribute()
34+
round = False
35+
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
36+
# U55 only supports INT32 and round == True
37+
# TODO MLETORCH-525 Emulate round == False with different decomposition
38+
round = True
39+
attr.ArithmeticRightShiftAttribute(round=round)
40+
41+
tosa_graph.addOperator(
42+
TosaOp.Op().ARITHMETIC_RIGHT_SHIFT,
43+
[inputs[0].name, inputs[1].name],
44+
[output.name],
45+
attr,
46+
)

backends/arm/operators/ops_binary.py

+3
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,6 @@ def define_node(
5252
binary_operator_factory("aten.logical_and.default", TosaOp.Op().LOGICAL_AND)
5353
binary_operator_factory("aten.logical_xor.default", TosaOp.Op().LOGICAL_XOR)
5454
binary_operator_factory("aten.logical_or.default", TosaOp.Op().LOGICAL_OR)
55+
binary_operator_factory(
56+
"aten.bitwise_left_shift.Tensor", TosaOp.Op().LOGICAL_LEFT_SHIFT
57+
)

backends/arm/test/models/test_conformer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,10 @@ class TestConformer(unittest.TestCase):
3131
# .to_executorch step, i.e. after Arm partitioner.
3232
ops_after_partitioner = {
3333
"executorch_exir_dialects_edge__ops_aten_max_default": 1,
34-
"executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
3534
"executorch_exir_dialects_edge__ops_aten_where_self": 4,
3635
"torch.ops.aten._assert_scalar.default": 10,
3736
"torch.ops.aten._local_scalar_dense.default": 1,
38-
"torch.ops.higher_order.executorch_call_delegate": 6,
37+
"torch.ops.higher_order.executorch_call_delegate": 4,
3938
}
4039

4140
dim = 16

backends/arm/test/models/test_llama.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_llama_tosa_MI(self):
114114
)
115115
.export()
116116
.to_edge_transform_and_lower()
117-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 26})
117+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 14})
118118
.to_executorch()
119119
.run_method_and_compare_outputs(
120120
inputs=llama_inputs,

0 commit comments

Comments
 (0)