Skip to content

Commit 265b9b7

Browse files
authored
Arm backend: Extend convolution support check to 3d (#9640)
Add conv3d tests, though most are skipped since conv3d support is not yet implemented. Signed-off-by: Erik Lundell <[email protected]>
1 parent 7159650 commit 265b9b7

File tree

3 files changed

+424
-19
lines changed

3 files changed

+424
-19
lines changed

backends/arm/operator_support/convolution_support.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _is_node_supported_u55(self, node: fx.Node):
5555

5656
C_in = shape_in[1]
5757
C_out = shape_out[1]
58-
if (C_in == group) and (C_out % C_in) == 0:
58+
if (C_in == group) and (C_out % C_in) == 0 and len(shape_in) <= 4:
5959
# Depthwise convolution
6060
for dim in shape_in[1:]:
6161
if not 1 <= dim <= 65536:
@@ -74,13 +74,19 @@ def _is_node_supported_u55(self, node: fx.Node):
7474

7575
kernel_w = kernel[2]
7676
kernel_h = kernel[3] if len(kernel) > 3 else 1
77+
kernel_z = kernel[4] if len(kernel) > 4 else 1
7778
# Kernel condition misses constraint on sum of absolute weights
7879
if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096:
7980
self.reporter.report_reject(
8081
node,
8182
f"Convolution needs to have kernel_y<=64, kernel_x*kernel_y<=4096, got kernel ({kernel_w}, {kernel_h})",
8283
)
8384
return False
85+
if kernel_z != 1:
86+
self.reporter.report_reject(
87+
node, f"Convolution3d needs to have kernel_z==1, got {kernel_z}."
88+
)
89+
return False
8490

8591
if not self._stride_condition(node):
8692
self.reporter.report_reject(
@@ -107,6 +113,14 @@ def _stride_condition(self, node: fx.Node) -> bool:
107113
if len(strides) == 1:
108114
strides = [strides[0]] * 2
109115

116+
if len(strides) > 2:
117+
stride_z = strides[2]
118+
if stride_z > 1:
119+
self.reporter.report_reject(
120+
node, f"Convolution3d only supports stride_z<=1, got {stride_z}."
121+
)
122+
return False
123+
110124
for stride, dilation in zip(strides, dilations):
111125
stride_condition = 1 <= stride <= 3
112126
dilation_condition = (not has_padding) and (dilation == 1)

backends/arm/test/ops/test_conv2d.py

+10-18
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
import torch
1010
from executorch.backends.arm.test import common
11-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1211
from executorch.backends.arm.test.tester.test_pipeline import (
1312
EthosU55PipelineBI,
1413
EthosU85PipelineBI,
14+
OpNotSupportedPipeline,
1515
TosaPipelineBI,
1616
TosaPipelineMI,
1717
)
@@ -34,9 +34,9 @@ def __init__(
3434
in_channels: Union[List, int, None] = None,
3535
out_channels: Union[List, int, None] = None,
3636
kernel_size: Union[List, Tuple, None] = None,
37-
stride: Union[List, Tuple, None] = None,
38-
padding: Union[List, Tuple, None] = None,
39-
dilation: Union[List, Tuple, None] = None,
37+
stride: Union[List, Tuple, int, None] = None,
38+
padding: Union[List, Tuple, int, None] = None,
39+
dilation: Union[List, Tuple, int, None] = None,
4040
groups: Union[List, int, None] = None,
4141
bias: Union[List, bool, None] = None,
4242
padding_mode: Union[List, str, None] = None,
@@ -446,17 +446,9 @@ def test_convolution_2d_u85_BI_on_fvp(test_module):
446446
def test_reject_convolution_2d_u55_BI(
447447
module: Conv2d,
448448
):
449-
(
450-
ArmTester(
451-
module,
452-
example_inputs=module.get_inputs(),
453-
compile_spec=common.get_u55_compile_spec(),
454-
)
455-
.quantize()
456-
.export()
457-
.check_count({"torch.ops.aten.conv2d.default": 1})
458-
.check(["torch.ops.quantized_decomposed"])
459-
.to_edge_transform_and_lower()
460-
.check(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
461-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
462-
)
449+
OpNotSupportedPipeline(
450+
module,
451+
module.get_inputs(),
452+
"TOSA-0.80+BI+u55",
453+
{"executorch_exir_dialects_edge__ops_aten_convolution_default": 1},
454+
).run()

0 commit comments

Comments
 (0)