Skip to content

Commit 17b933f

Browse files
Arm backend: Convert asserts to raise errors in op_avg_pool2d (#10516)
Asserts are converted to proper raises to ensure graph integrity. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 0fa0003 commit 17b933f

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

backends/arm/operators/op_avg_pool2d.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,12 @@ def define_node(
8585
) -> None:
8686
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
8787

88-
input_tensor = inputs[0]
89-
assert input_tensor.dtype == ts.DType.INT8
88+
supported_dtypes = [ts.DType.INT8]
89+
if inputs[0].dtype not in supported_dtypes:
90+
raise TypeError(
91+
f"IO data type needs to be one of {supported_dtypes}, got "
92+
f'"{inputs[0].dtype}"'
93+
)
9094

9195
accumulator_type = ts.DType.INT32
9296

@@ -118,9 +122,12 @@ def define_node(
118122
) -> None:
119123
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120124

121-
assert (
122-
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
123-
), "Only FP32 and INT8 supported"
125+
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
126+
if inputs[0].dtype not in supported_dtypes:
127+
raise TypeError(
128+
f"IO data type needs to be one of {supported_dtypes}, got "
129+
f'"{inputs[0].dtype}"'
130+
)
124131

125132
if inputs[0].dtype == ts.DType.INT8:
126133
super().define_node(node, tosa_graph, inputs, output)
@@ -205,8 +212,12 @@ def define_node(
205212
) -> None:
206213
import serializer.tosa_serializer as ts # type: ignore
207214

208-
input_tensor = inputs[0]
209-
assert input_tensor.dtype == ts.DType.INT8
215+
supported_dtypes = [ts.DType.INT8]
216+
if inputs[0].dtype not in supported_dtypes:
217+
raise TypeError(
218+
f"IO data type needs to be one of {supported_dtypes}, got "
219+
f'"{inputs[0].dtype}"'
220+
)
210221

211222
accumulator_type = ts.DType.INT32
212223

@@ -241,9 +252,12 @@ def define_node(
241252
) -> None:
242253
import serializer.tosa_serializer as ts # type: ignore
243254

244-
assert (
245-
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
246-
), "Only FP32 and INT8 supported"
255+
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
256+
if inputs[0].dtype not in supported_dtypes:
257+
raise TypeError(
258+
f"IO data type needs to be one of {supported_dtypes}, got "
259+
f'"{inputs[0].dtype}"'
260+
)
247261

248262
if inputs[0].dtype == ts.DType.INT8:
249263
super().define_node(node, tosa_graph, inputs, output)

0 commit comments

Comments
 (0)