@@ -85,8 +85,12 @@ def define_node(
85
85
) -> None :
86
86
import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
87
87
88
- input_tensor = inputs [0 ]
89
- assert input_tensor .dtype == ts .DType .INT8
88
+ supported_dtypes = [ts .DType .INT8 ]
89
+ if inputs [0 ].dtype not in supported_dtypes :
90
+ raise TypeError (
91
+ f"IO data type needs to be one of { supported_dtypes } , got "
92
+ f'"{ inputs [0 ].dtype } "'
93
+ )
90
94
91
95
accumulator_type = ts .DType .INT32
92
96
@@ -118,9 +122,12 @@ def define_node(
118
122
) -> None :
119
123
import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
120
124
121
- assert (
122
- inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
123
- ), "Only FP32 and INT8 supported"
125
+ supported_dtypes = [ts .DType .INT8 , ts .DType .FP32 ]
126
+ if inputs [0 ].dtype not in supported_dtypes :
127
+ raise TypeError (
128
+ f"IO data type needs to be one of { supported_dtypes } , got "
129
+ f'"{ inputs [0 ].dtype } "'
130
+ )
124
131
125
132
if inputs [0 ].dtype == ts .DType .INT8 :
126
133
super ().define_node (node , tosa_graph , inputs , output )
@@ -205,8 +212,12 @@ def define_node(
205
212
) -> None :
206
213
import serializer .tosa_serializer as ts # type: ignore
207
214
208
- input_tensor = inputs [0 ]
209
- assert input_tensor .dtype == ts .DType .INT8
215
+ supported_dtypes = [ts .DType .INT8 ]
216
+ if inputs [0 ].dtype not in supported_dtypes :
217
+ raise TypeError (
218
+ f"IO data type needs to be one of { supported_dtypes } , got "
219
+ f'"{ inputs [0 ].dtype } "'
220
+ )
210
221
211
222
accumulator_type = ts .DType .INT32
212
223
@@ -241,9 +252,12 @@ def define_node(
241
252
) -> None :
242
253
import serializer .tosa_serializer as ts # type: ignore
243
254
244
- assert (
245
- inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
246
- ), "Only FP32 and INT8 supported"
255
+ supported_dtypes = [ts .DType .INT8 , ts .DType .FP32 ]
256
+ if inputs [0 ].dtype not in supported_dtypes :
257
+ raise TypeError (
258
+ f"IO data type needs to be one of { supported_dtypes } , got "
259
+ f'"{ inputs [0 ].dtype } "'
260
+ )
247
261
248
262
if inputs [0 ].dtype == ts .DType .INT8 :
249
263
super ().define_node (node , tosa_graph , inputs , output )
0 commit comments