Skip to content

Commit 135e875

Browse files
Arm backend: enable dim_order (#7831)
Add support for to_dim_order_copy With edge_compile_config.skip_dim_order = True removed, to_copy will be converted into to_dim_order_copy nodes. This commit moves our logic from to_copy into to_dim_order_copy. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 2516049 commit 135e875

File tree

11 files changed

+72
-53
lines changed

11 files changed

+72
-53
lines changed

backends/arm/operator_support/to_copy_support.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -22,7 +22,10 @@
2222

2323
@register_tosa_support_check
2424
class ToCopySupported(SupportedTOSAOperatorCheck):
25-
targets = [exir_ops.edge.aten._to_copy.default]
25+
targets = [
26+
exir_ops.edge.aten._to_copy.default,
27+
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
28+
]
2629

2730
tosa_specs = [
2831
TosaSpecification.create_from_string("TOSA-0.80+BI"),
@@ -110,7 +113,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
110113
)
111114
return False
112115

113-
# Check memory format
116+
# Check memory format (to_copy)
114117
if "memory_format" in node.kwargs:
115118
if node.kwargs["memory_format"] in (torch.preserve_format,):
116119
logger.info(
@@ -119,4 +122,14 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
119122
)
120123
return False
121124

125+
# Check dim_order (to_dim_order_copy)
126+
if "dim_order" in node.kwargs:
127+
dim_order = node.kwargs["dim_order"]
128+
if dim_order != list(range(len(dim_order))):
129+
logger.info(
130+
f"Argument {dim_order=} is not supported for "
131+
f"{node.target.name()} right now." # pyre-ignore[16]
132+
)
133+
return False
134+
122135
return True

backends/arm/operators/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
op_table,
3636
op_tanh,
3737
op_to_copy,
38+
op_to_dim_order_copy,
3839
op_transpose,
3940
op_upsample_nearest2d,
4041
op_view,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts
10+
import torch
11+
import tosa.Op as TosaOp
12+
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
19+
20+
@register_node_visitor
21+
class ToDimOrderCopyVisitor(NodeVisitor):
22+
"""
23+
Implement the type cast functionality of _to_dim_order_copy.
24+
25+
Other features like setting of the dim_order or moving a tensor to a
26+
different device are not supported.
27+
28+
Also note that the node should not be quantized.
29+
"""
30+
31+
target = "dim_order_ops._to_dim_order_copy.default"
32+
33+
def define_node(
34+
self,
35+
node: torch.fx.Node,
36+
tosa_graph: ts.TosaSerializer,
37+
inputs: List[TosaArg],
38+
output: TosaArg,
39+
) -> None:
40+
tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name])

backends/arm/test/models/test_mobilenet_v2_arm.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from executorch.backends.arm.test import common, conftest
1515

1616
from executorch.backends.arm.test.tester.arm_tester import ArmTester
17-
from executorch.exir import EdgeCompileConfig
1817
from torchvision import models, transforms
1918
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
2019

@@ -47,10 +46,6 @@ class TestMobileNetV2(unittest.TestCase):
4746
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
4847
}
4948

50-
_edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
51-
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
52-
)
53-
5449
def test_mv2_tosa_MI(self):
5550
(
5651
ArmTester(
@@ -59,7 +54,7 @@ def test_mv2_tosa_MI(self):
5954
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
6055
)
6156
.export()
62-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
57+
.to_edge_transform_and_lower()
6358
.to_executorch()
6459
.run_method_and_compare_outputs(inputs=self.model_inputs)
6560
)
@@ -73,7 +68,7 @@ def test_mv2_tosa_BI(self):
7368
)
7469
.quantize()
7570
.export()
76-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
71+
.to_edge_transform_and_lower()
7772
.to_executorch()
7873
# atol=1.0 is a defensive upper limit
7974
# TODO MLETROCH-72
@@ -92,7 +87,7 @@ def test_mv2_u55_BI(self):
9287
)
9388
.quantize()
9489
.export()
95-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
90+
.to_edge_transform_and_lower()
9691
.to_executorch()
9792
.serialize()
9893
)
@@ -112,7 +107,7 @@ def test_mv2_u85_BI(self):
112107
)
113108
.quantize()
114109
.export()
115-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
110+
.to_edge_transform_and_lower()
116111
.to_executorch()
117112
.serialize()
118113
)

backends/arm/test/ops/test_add.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torch
1414
from executorch.backends.arm.test import common, conftest
1515
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16-
from executorch.exir import EdgeCompileConfig
1716
from executorch.exir.backend.compile_spec_schema import CompileSpec
1817
from parameterized import parameterized
1918

@@ -51,10 +50,6 @@ def __init__(self):
5150
def forward(self, x, y):
5251
return x + y
5352

54-
_edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
55-
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
56-
)
57-
5853
def _test_add_tosa_MI_pipeline(
5954
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
6055
):
@@ -67,7 +62,7 @@ def _test_add_tosa_MI_pipeline(
6762
.export()
6863
.check_count({"torch.ops.aten.add.Tensor": 1})
6964
.check_not(["torch.ops.quantized_decomposed"])
70-
.to_edge(config=self._edge_compile_config)
65+
.to_edge()
7166
.partition()
7267
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
7368
.to_executorch()
@@ -87,7 +82,7 @@ def _test_add_tosa_BI_pipeline(
8782
.export()
8883
.check_count({"torch.ops.aten.add.Tensor": 1})
8984
.check(["torch.ops.quantized_decomposed"])
90-
.to_edge(config=self._edge_compile_config)
85+
.to_edge()
9186
.partition()
9287
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
9388
.to_executorch()

backends/arm/test/ops/test_linear.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from executorch.backends.arm.test import common, conftest
1717

1818
from executorch.backends.arm.test.tester.arm_tester import ArmTester
19-
from executorch.exir import EdgeCompileConfig
2019
from executorch.exir.backend.compile_spec_schema import CompileSpec
2120
from parameterized import parameterized
2221

@@ -108,10 +107,6 @@
108107
class TestLinear(unittest.TestCase):
109108
"""tests the linear operation y = Ax + b"""
110109

111-
_edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
112-
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
113-
)
114-
115110
class Linear(torch.nn.Module):
116111
def __init__(
117112
self,
@@ -143,7 +138,7 @@ def _test_linear_tosa_MI_pipeline(
143138
.export()
144139
.check_count({"torch.ops.aten.linear.default": 1})
145140
.check_not(["torch.ops.quantized_decomposed"])
146-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
141+
.to_edge_transform_and_lower()
147142
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
148143
.to_executorch()
149144
.run_method_and_compare_outputs(inputs=test_data)
@@ -164,7 +159,7 @@ def _test_linear_tosa_BI_pipeline(
164159
.export()
165160
.check_count({"torch.ops.aten.linear.default": 1})
166161
.check(["torch.ops.quantized_decomposed"])
167-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
162+
.to_edge_transform_and_lower()
168163
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
169164
.to_executorch()
170165
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
@@ -186,7 +181,7 @@ def _test_linear_tosa_ethosu_BI_pipeline(
186181
.export()
187182
.check_count({"torch.ops.aten.linear.default": 1})
188183
.check(["torch.ops.quantized_decomposed"])
189-
.to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config)
184+
.to_edge_transform_and_lower()
190185
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
191186
.to_executorch()
192187
.serialize()

backends/arm/test/ops/test_maximum.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import torch
1313
from executorch.backends.arm.test import common, conftest
1414
from executorch.backends.arm.test.tester.arm_tester import ArmTester
15-
from executorch.exir import EdgeCompileConfig
1615
from executorch.exir.backend.compile_spec_schema import CompileSpec
1716
from parameterized import parameterized
1817

@@ -38,10 +37,6 @@ def __init__(self):
3837
def forward(self, x, y):
3938
return torch.maximum(x, y)
4039

41-
_edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
42-
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
43-
)
44-
4540
def _test_maximum_tosa_MI_pipeline(
4641
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
4742
):
@@ -54,7 +49,7 @@ def _test_maximum_tosa_MI_pipeline(
5449
.export()
5550
.check_count({"torch.ops.aten.maximum.default": 1})
5651
.check_not(["torch.ops.quantized_decomposed"])
57-
.to_edge(config=self._edge_compile_config)
52+
.to_edge()
5853
.partition()
5954
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
6055
.to_executorch()
@@ -74,7 +69,7 @@ def _test_maximum_tosa_BI_pipeline(
7469
.export()
7570
.check_count({"torch.ops.aten.maximum.default": 1})
7671
.check(["torch.ops.quantized_decomposed"])
77-
.to_edge(config=self._edge_compile_config)
72+
.to_edge()
7873
.partition()
7974
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
8075
.to_executorch()

backends/arm/test/ops/test_minimum.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import torch
1313
from executorch.backends.arm.test import common, conftest
1414
from executorch.backends.arm.test.tester.arm_tester import ArmTester
15-
from executorch.exir import EdgeCompileConfig
1615
from executorch.exir.backend.compile_spec_schema import CompileSpec
1716
from parameterized import parameterized
1817

@@ -38,10 +37,6 @@ def __init__(self):
3837
def forward(self, x, y):
3938
return torch.minimum(x, y)
4039

41-
_edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
42-
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
43-
)
44-
4540
def _test_minimum_tosa_MI_pipeline(
4641
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
4742
):
@@ -54,7 +49,7 @@ def _test_minimum_tosa_MI_pipeline(
5449
.export()
5550
.check_count({"torch.ops.aten.minimum.default": 1})
5651
.check_not(["torch.ops.quantized_decomposed"])
57-
.to_edge(config=self._edge_compile_config)
52+
.to_edge()
5853
.partition()
5954
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
6055
.to_executorch()
@@ -74,7 +69,7 @@ def _test_minimum_tosa_BI_pipeline(
7469
.export()
7570
.check_count({"torch.ops.aten.minimum.default": 1})
7671
.check(["torch.ops.quantized_decomposed"])
77-
.to_edge(config=self._edge_compile_config)
72+
.to_edge()
7873
.partition()
7974
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
8075
.to_executorch()

backends/arm/test/ops/test_sum.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import torch
1212
from executorch.backends.arm.test import common
1313
from executorch.backends.arm.test.tester.arm_tester import ArmTester
14-
from executorch.exir import EdgeCompileConfig
1514
from executorch.exir.backend.compile_spec_schema import CompileSpec
1615
from parameterized import parameterized
1716

@@ -47,10 +46,6 @@ class Sum(torch.nn.Module):
4746
def forward(self, x: torch.Tensor, dim: int, keepdim: bool):
4847
return x.sum(dim=dim, keepdim=keepdim)
4948

50-
_edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
51-
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
52-
)
53-
5449
def _test_sum_tosa_MI_pipeline(
5550
self, module: torch.nn.Module, test_data: tuple[exampledata_t]
5651
):
@@ -63,7 +58,7 @@ def _test_sum_tosa_MI_pipeline(
6358
.export()
6459
.check_count({"torch.ops.aten.sum.dim_IntList": 1})
6560
.check_not(["torch.ops.quantized_decomposed"])
66-
.to_edge(config=self._edge_compile_config)
61+
.to_edge()
6762
.partition()
6863
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
6964
.to_executorch()
@@ -83,7 +78,7 @@ def _test_sum_tosa_BI_pipeline(
8378
.export()
8479
.check_count({"torch.ops.aten.sum.dim_IntList": 1})
8580
.check(["torch.ops.quantized_decomposed"])
86-
.to_edge(config=self._edge_compile_config)
81+
.to_edge()
8782
.partition()
8883
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
8984
.to_executorch()

backends/arm/test/tester/arm_tester.py

-3
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,6 @@ def to_edge(
227227
if config is not None:
228228
to_edge_stage.edge_compile_conf = config
229229

230-
# TODO(T182928844): Delegate dim order op to backend.
231-
to_edge_stage.edge_compile_conf._skip_dim_order = True
232230
return super().to_edge(to_edge_stage)
233231

234232
def partition(self, partition_stage: Optional[Partition] = None):
@@ -254,7 +252,6 @@ def to_edge_transform_and_lower(
254252
to_edge_and_lower_stage.partitioners = partitioners
255253
if edge_compile_config is not None:
256254
to_edge_and_lower_stage.edge_compile_conf = edge_compile_config
257-
to_edge_and_lower_stage.edge_compile_conf._skip_dim_order = True
258255
return super().to_edge_transform_and_lower(to_edge_and_lower_stage)
259256

260257
def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] | None = None):

examples/arm/aot_arm_compiler.py

-2
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,6 @@ def get_args():
527527
partitioner=[ArmPartitioner(compile_spec)],
528528
compile_config=EdgeCompileConfig(
529529
_check_ir_validity=False,
530-
_skip_dim_order=True,
531530
),
532531
)
533532

@@ -553,7 +552,6 @@ def get_args():
553552
exported_program,
554553
compile_config=EdgeCompileConfig(
555554
_check_ir_validity=False,
556-
_skip_dim_order=True,
557555
),
558556
)
559557

0 commit comments

Comments
 (0)