Skip to content

Commit 94a8bc7

Browse files
Arm backend: Add logical And, Or, and Xor operators in Arm backend (#9036)
1. Implement the logical operators using binary operator factory. 2. Disable the logical operators for EthosU55. Signed-off-by: Yufeng Shi <[email protected]>
1 parent 5344a1a commit 94a8bc7

File tree

3 files changed

+153
-0
lines changed

3 files changed

+153
-0
lines changed

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ def is_node_supported(
112112
supported = node.op == "call_function" and node.target in [
113113
exir_ops.edge.aten.abs.default,
114114
exir_ops.edge.aten.add.Tensor,
115+
exir_ops.edge.aten.logical_and.default,
116+
exir_ops.edge.aten.logical_or.default,
117+
exir_ops.edge.aten.logical_xor.default,
115118
exir_ops.edge.aten.bitwise_and.Tensor,
116119
exir_ops.edge.aten.bitwise_or.Tensor,
117120
exir_ops.edge.aten.bitwise_xor.Tensor,
@@ -193,6 +196,9 @@ def is_node_supported(
193196
exir_ops.edge.aten.bitwise_and.Tensor,
194197
exir_ops.edge.aten.bitwise_or.Tensor,
195198
exir_ops.edge.aten.bitwise_xor.Tensor,
199+
exir_ops.edge.aten.logical_and.default,
200+
exir_ops.edge.aten.logical_or.default,
201+
exir_ops.edge.aten.logical_xor.default,
196202
exir_ops.edge.aten.amax.default,
197203
exir_ops.edge.aten.amin.default,
198204
exir_ops.edge.aten.eq.Tensor,

backends/arm/operators/ops_binary.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,6 @@ def define_node(
4949
# Register the TOSA elementwise binary bitwise/logical operators through the
# shared binary-operator factory, one (edge-op name, TOSA opcode) pair each.
for _edge_op_name, _tosa_binary_op in (
    ("aten.bitwise_and.Tensor", TosaOp.Op().BITWISE_AND),
    ("aten.bitwise_xor.Tensor", TosaOp.Op().BITWISE_XOR),
    ("aten.bitwise_or.Tensor", TosaOp.Op().BITWISE_OR),
    ("aten.logical_and.default", TosaOp.Op().LOGICAL_AND),
    ("aten.logical_xor.default", TosaOp.Op().LOGICAL_XOR),
    ("aten.logical_or.default", TosaOp.Op().LOGICAL_OR),
):
    binary_operator_factory(_edge_op_name, _tosa_binary_op)

backends/arm/test/ops/test_logical.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Tuple
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU85PipelineBI,
13+
OpNotSupportedPipeline,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
19+
class And(torch.nn.Module):
    """Elementwise logical AND of two tensors, for Arm backend export tests."""

    aten_op = "torch.ops.aten.logical_and.default"
    exir_op = "executorch_exir_dialects_edge__ops_aten_logical_and_default"

    def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
        # Functional form is equivalent to tensor1.logical_and(tensor2).
        return torch.logical_and(tensor1, tensor2)
25+
26+
27+
class Xor(torch.nn.Module):
    """Elementwise logical XOR of two tensors, for Arm backend export tests."""

    aten_op = "torch.ops.aten.logical_xor.default"
    exir_op = "executorch_exir_dialects_edge__ops_aten_logical_xor_default"

    def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
        # Functional form is equivalent to tensor1.logical_xor(tensor2).
        return torch.logical_xor(tensor1, tensor2)
33+
34+
35+
class Or(torch.nn.Module):
    """Elementwise logical OR of two tensors, for Arm backend export tests."""

    aten_op = "torch.ops.aten.logical_or.default"
    exir_op = "executorch_exir_dialects_edge__ops_aten_logical_or_default"

    def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
        # Functional form is equivalent to tensor1.logical_or(tensor2).
        return torch.logical_or(tensor1, tensor2)
41+
42+
43+
input_t2 = Tuple[torch.Tensor, torch.Tensor]  # Input x, y

# Shared test inputs, keyed by a short description of tensor rank/content.
# Fix: the annotation must name both key and value types; `dict[input_t2]`
# is an invalid single-argument subscript for dict.
test_input: dict[str, input_t2] = {
    "rank1": (
        torch.tensor([True, True, False, False], dtype=torch.bool),
        torch.tensor([True, False, True, False], dtype=torch.bool),
    ),
    "rand_rank2": (
        torch.randint(0, 2, (10, 10), dtype=torch.bool),
        torch.randint(0, 2, (10, 10), dtype=torch.bool),
    ),
    "rand_rank3": (
        torch.randint(0, 2, (10, 10, 10), dtype=torch.bool),
        torch.randint(0, 2, (10, 10, 10), dtype=torch.bool),
    ),
    "rand_rank4": (
        torch.randint(0, 2, (1, 10, 10, 10), dtype=torch.bool),
        torch.randint(0, 2, (1, 10, 10, 10), dtype=torch.bool),
    ),
}
64+
65+
66+
# One (fresh module instance, shared inputs) pair per operator x input case,
# keyed "<op>_<case>" — e.g. "and_rank1", "xor_rand_rank4".
test_data = {
    f"{op_name}_{case}": (op_cls(), test_input[case])
    for op_name, op_cls in (("and", And), ("xor", Xor), ("or", Or))
    for case in ("rank1", "rand_rank2", "rand_rank3", "rand_rank4")
}
80+
81+
82+
# Every logical-op case currently fails on the Ethos-U85 FVP for the same
# reason, so generate the xfail table instead of repeating the string twelve
# times. Fix: the three "*_rank1" entries were missing the ":" after the
# ticket id, making their reason strings inconsistent with the other nine.
_FVP_XFAIL_REASON = "MLETORCH-706: Support ScalarType::Bool in EthosUBackend."
fvp_xfails = {
    f"{op_name}_{case}": _FVP_XFAIL_REASON
    for op_name in ("and", "xor", "or")
    for case in ("rank1", "rand_rank2", "rand_rank3", "rand_rank4")
}
96+
97+
98+
@common.parametrize("test_data", test_data)
def test_logical_tosa_MI(test_data: input_t2):
    """Run each logical operator through the TOSA MI pipeline."""
    module, inputs = test_data
    TosaPipelineMI[input_t2](module, inputs, module.aten_op, module.exir_op).run()
103+
104+
105+
@common.parametrize("test_data", test_data)
def test_logical_tosa_BI(test_data: input_t2):
    """Run each logical operator through the TOSA BI pipeline.

    The inputs are bool tensors, so quantization does not apply: the
    "quantize" stage and the stage immediately after it are removed.
    """
    module, inputs = test_data
    pipeline = TosaPipelineBI[input_t2](module, inputs, module.aten_op, module.exir_op)
    # Pop the stage after "quantize" first so its index is still valid.
    pipeline.pop_stage(pipeline.find_pos("quantize") + 1)
    pipeline.pop_stage("quantize")
    pipeline.run()
112+
113+
114+
@common.parametrize("test_data", test_data)
def test_logical_u55_BI(test_data: input_t2):
    """Verify logical ops are NOT delegated on U55, where they are unsupported."""
    module, inputs = test_data
    pipeline = OpNotSupportedPipeline[input_t2](
        # Expect exactly one non-delegated occurrence of the edge op.
        module, inputs, "TOSA-0.80+BI+u55", {module.exir_op: 1}
    )
    pipeline.run()
122+
123+
124+
@common.parametrize("test_data", test_data)
def test_logical_u85_BI(test_data: input_t2):
    """Run each logical operator through the Ethos-U85 BI pipeline (no FVP)."""
    module, inputs = test_data
    pipeline = EthosU85PipelineBI[input_t2](
        module, inputs, module.aten_op, module.exir_op, run_on_fvp=False
    )
    # Bool inputs need no quantization: drop "quantize" and its successor
    # stage (successor first, so the looked-up index stays valid).
    pipeline.pop_stage(pipeline.find_pos("quantize") + 1)
    pipeline.pop_stage("quantize")
    pipeline.run()
133+
134+
135+
@common.parametrize("test_data", test_data, fvp_xfails)
@common.SkipIfNoCorstone320
def test_logical_u85_BI_on_fvp(test_data: input_t2):
    """Run each logical operator on the Corstone-320 FVP.

    All cases are currently expected to fail (see fvp_xfails / MLETORCH-706).
    """
    module, inputs = test_data
    pipeline = EthosU85PipelineBI[input_t2](
        module, inputs, module.aten_op, module.exir_op, run_on_fvp=True
    )
    # Bool inputs need no quantization: drop "quantize" and its successor
    # stage (successor first, so the looked-up index stays valid).
    pipeline.pop_stage(pipeline.find_pos("quantize") + 1)
    pipeline.pop_stage("quantize")
    pipeline.run()

0 commit comments

Comments
 (0)