4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
# pyre-unsafe
7
- from typing import List
7
+ from typing import Any , List
8
8
9
9
import executorch .backends .arm .tosa_quant_utils as tqutils
10
10
import executorch .backends .arm .tosa_utils as tutils
11
11
12
- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
13
12
from executorch .backends .arm .operators .node_visitor import (
14
13
NodeVisitor ,
15
14
register_node_visitor ,
@@ -33,10 +32,13 @@ def __init__(self, *args):
33
32
def define_node (
34
33
self ,
35
34
node : Node ,
36
- tosa_graph : ts . TosaSerializer ,
35
+ tosa_graph : Any ,
37
36
inputs : List [TosaArg ],
38
37
output : TosaArg ,
39
38
) -> None :
39
+
40
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
41
+
40
42
# Specification (0.80) states that input and output types
41
43
# should all be the same
42
44
if not (inputs [0 ].dtype == output .dtype ):
@@ -53,7 +55,7 @@ def define_node(
53
55
if inputs [0 ].dtype == ts .DType .INT8 :
54
56
rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
55
57
tosa_graph , inputs , node
56
- )
58
+ ) # type: ignore[possibly-undefined]
57
59
else :
58
60
# input[0].dtype == ts.DType.INT32
59
61
# Non quantized input, natively support by TOSA.abs
@@ -96,10 +98,13 @@ def __init__(self, *args):
96
98
def define_node (
97
99
self ,
98
100
node : Node ,
99
- tosa_graph : ts . TosaSerializer ,
101
+ tosa_graph : Any ,
100
102
inputs : List [TosaArg ],
101
103
output : TosaArg ,
102
104
) -> None :
105
+
106
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
107
+
103
108
# Specification (0.80) states that input and output types
104
109
# should all be the same
105
110
if not (inputs [0 ].dtype == output .dtype ):
@@ -129,3 +134,122 @@ def define_node(
129
134
[output .name ],
130
135
None ,
131
136
)
137
+
138
+
139
+ @register_node_visitor
140
+ class AbsVisitor_INT (NodeVisitor ):
141
+ target = "aten.abs.default"
142
+
143
+ tosa_specs = [
144
+ TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
145
+ ]
146
+
147
+ def __init__ (self , * args ):
148
+ super ().__init__ (* args )
149
+
150
+ def define_node (
151
+ self ,
152
+ node : Node ,
153
+ tosa_graph : Any ,
154
+ inputs : List [TosaArg ],
155
+ output : TosaArg ,
156
+ ) -> None :
157
+
158
+ import serializer .tosa_serializer as ts # type: ignore
159
+
160
+ # Specification (1.0) states that input and output types
161
+ # should all be the same
162
+ if not (inputs [0 ].dtype == output .dtype ):
163
+ raise ValueError (
164
+ "All inputs and outputs need same dtype."
165
+ f"Got { inputs [0 ].dtype = } , { output .dtype = } "
166
+ )
167
+ # Handle int8 (quantized) and int32
168
+ if not (inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]):
169
+ raise ValueError (
170
+ "All inputs need to be INT8 or INT32." f"Got { inputs [0 ].dtype = } "
171
+ )
172
+
173
+ scale_back = 1.0
174
+ if inputs [0 ].dtype == ts .DType .INT8 :
175
+ rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
176
+ tosa_graph , inputs , node , self .tosa_specs
177
+ ) # type: ignore[possibly-undefined]
178
+ else :
179
+ # input[0].dtype == ts.DType.INT32
180
+ # Non quantized input, natively support by TOSA.abs
181
+ rescaled_inputs = inputs
182
+
183
+ if output .dtype == ts .DType .INT8 :
184
+ broadcasted_shape = tutils .tosa_shape (output .shape , output .dim_order )
185
+ abs_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
186
+ else :
187
+ # output.dtype == ts.DType.INT32
188
+ abs_output = output
189
+
190
+ # Do the INT32 Abs
191
+ tosa_graph .addOperator (
192
+ ts .TosaOp .Op ().ABS ,
193
+ [
194
+ rescaled_inputs [0 ].name ,
195
+ ],
196
+ [abs_output .name ],
197
+ None ,
198
+ )
199
+
200
+ if output .dtype == ts .DType .INT8 :
201
+ # Scale output back to 8 bit
202
+ # pyre-ignore
203
+ tqutils .insert_rescale_op_to_int8 (
204
+ tosa_graph , abs_output , scale_back , node , self .tosa_specs
205
+ ) # type: ignore[possibly-undefined]
206
+
207
+
208
+ @register_node_visitor
209
+ class AbsVisitor_FP (AbsVisitor_INT ):
210
+ # inheriting 'target' from BI class
211
+
212
+ tosa_specs = [TosaSpecification .create_from_string ("TOSA-1.0+FP" )]
213
+
214
+ def __init__ (self , * args ):
215
+ super ().__init__ (* args )
216
+
217
+ def define_node (
218
+ self ,
219
+ node : Node ,
220
+ tosa_graph : Any ,
221
+ inputs : List [TosaArg ],
222
+ output : TosaArg ,
223
+ ) -> None :
224
+
225
+ import serializer .tosa_serializer as ts # type: ignore
226
+
227
+ # Specification (1.0) states that input and output types
228
+ # should all be the same
229
+ if not (inputs [0 ].dtype == output .dtype ):
230
+ raise ValueError (
231
+ "All inputs and output need same dtype."
232
+ f"Got { inputs [0 ].dtype = } , { output .dtype = } "
233
+ )
234
+
235
+ if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
236
+ # Call the inherited define_node for handling integers
237
+ super ().define_node (node , tosa_graph , inputs , output )
238
+ else :
239
+ # FP32 Abs lowering
240
+
241
+ if not (inputs [0 ].dtype == ts .DType .FP32 ):
242
+ raise ValueError (
243
+ "All inputs need to be FP32." f"Got { inputs [0 ].dtype = } "
244
+ )
245
+
246
+ if not (output .dtype == ts .DType .FP32 ):
247
+ raise ValueError ("All outputs need to be FP32." f"Got { output .dtype = } " )
248
+
249
+ # MI lowering
250
+ tosa_graph .addOperator (
251
+ ts .TosaOp .Op ().ABS ,
252
+ [inputs [0 ].name ],
253
+ [output .name ],
254
+ None ,
255
+ )
0 commit comments