11
11
import torch .fx as fx
12
12
from executorch .backends .arm ._passes .arm_pass_utils import get_first_fake_tensor
13
13
from executorch .backends .arm ._passes .insert_table_ops import TableOps
14
+ from executorch .backends .arm .operators .op_permute import transform_permutation_vector
15
+ from executorch .backends .arm .tosa_utils import tosa_shape
14
16
from executorch .exir .backend .utils import WhyNoPartitionReporter
15
17
16
18
from executorch .exir .dialects ._ops import ops as exir_ops
17
19
from torch .fx .passes .operator_support import OperatorSupportBase
18
20
19
21
22
def _try_determine_dtype(node: fx.Node) -> torch.dtype | None:
    """Attempt to figure out the quantized data type of node. On failure, return None."""
    dtype = get_first_fake_tensor(node).dtype
    if not dtype.is_floating_point:
        return dtype

    if node.target is exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default:
        # Output of a dequantize is float; the interesting dtype is that of
        # the quantized input tensor.
        return get_first_fake_tensor(node.all_input_nodes[0]).dtype

    # Guard against nodes without users; indexing into an empty user list
    # below would raise IndexError.
    if len(node.users) == 0:
        return None

    q_node = list(node.users)[0]
    if q_node.target is exir_ops.edge.quantized_decomposed.quantize_per_tensor.default:
        # The last argument of quantize_per_tensor is the target dtype.
        return typing.cast(torch.dtype, q_node.args[-1])

    # We can't easily figure out dtype, return None
    return None
33
+
34
+
20
35
class EthosU55DtypeSupport (OperatorSupportBase ):
21
36
22
37
def __init__ (self , reporter : WhyNoPartitionReporter ):
@@ -33,37 +48,11 @@ def __init__(self, reporter: WhyNoPartitionReporter):
33
48
34
49
target_ops_i8 = tuple (TableOps .included_ops ())
35
50
36
- def _try_determine_dtype (self , node : fx .Node ) -> torch .dtype | None :
37
- """Attempt to figure out the quantized data type of node. On failure, return None."""
38
-
39
- dtype = get_first_fake_tensor (node ).dtype
40
- if not dtype .is_floating_point :
41
- return dtype
42
-
43
- if (
44
- node .target
45
- is exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default
46
- ):
47
- return get_first_fake_tensor (node .all_input_nodes [0 ]).dtype
48
-
49
- if len (node .users ) == 0 :
50
- return None
51
-
52
- q_node = list (node .users )[0 ]
53
- if (
54
- q_node .target
55
- is exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
56
- ):
57
- return typing .cast (torch .dtype , q_node .args [- 1 ])
58
-
59
- # We can't easily figure out dtype, return None
60
- return None
61
-
62
51
def is_node_supported ( # noqa: C901
63
52
self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
64
53
) -> bool :
65
54
66
- dtype = self . _try_determine_dtype (node )
55
+ dtype = _try_determine_dtype (node )
67
56
if dtype is None :
68
57
# If we couldn't determine dtype, just return ok.
69
58
return True
@@ -84,21 +73,21 @@ def is_node_supported( # noqa: C901
84
73
85
74
if node .target == exir_ops .edge .aten .convolution .default :
86
75
ifm , weight = node .all_input_nodes [0 :2 ]
87
- ifm_dtype = self . _try_determine_dtype (ifm )
76
+ ifm_dtype = _try_determine_dtype (ifm )
88
77
if ifm_dtype is not None and ifm_dtype not in (torch .int8 , torch .int16 ):
89
78
self .reporter .report_reject (
90
79
node , f"Unsupported input dtype { dtype } (Supports i8, i16)."
91
80
)
92
81
return False
93
- weight_dtype = self . _try_determine_dtype (weight )
82
+ weight_dtype = _try_determine_dtype (weight )
94
83
if weight_dtype is not None and weight_dtype not in (torch .int8 ,):
95
84
self .reporter .report_reject (
96
85
node , f"Unsupported weight dtype { dtype } (Supports i8)."
97
86
)
98
87
return False
99
88
if len (node .all_input_nodes ) > 2 :
100
89
bias = node .all_input_nodes [2 ]
101
- bias_dtype = self . _try_determine_dtype (bias )
90
+ bias_dtype = _try_determine_dtype (bias )
102
91
if bias_dtype is not None and bias_dtype not in (torch .int32 ,):
103
92
self .reporter .report_reject (
104
93
node , f"Unsupported bias dtype { dtype } (Supports i32)."
@@ -110,7 +99,7 @@ def is_node_supported( # noqa: C901
110
99
exir_ops .edge .aten .bmm .default ,
111
100
):
112
101
for input_node in node .all_input_nodes :
113
- dtype = self . _try_determine_dtype (input_node )
102
+ dtype = _try_determine_dtype (input_node )
114
103
if dtype is not None and dtype != torch .int8 :
115
104
self .reporter .report_reject (
116
105
input_node ,
@@ -174,3 +163,114 @@ def is_node_supported(
174
163
return False
175
164
176
165
return True
166
+
167
+
168
+ shape_t = list [int ]
169
+
170
+
171
+ class EthosU55TransposeCheck (OperatorSupportBase ):
172
+
173
+ def __init__ (self , reporter : WhyNoPartitionReporter ):
174
+ super ().__init__ ()
175
+ self .reporter = reporter
176
+
177
+ def _pad_to_rank_4 (
178
+ self , shape : shape_t , permutation : list [int ]
179
+ ) -> tuple [shape_t , shape_t ]:
180
+ diff = 4 - len (shape )
181
+ padded_shape = [1 ] * diff + shape
182
+ for i in range (len (permutation )):
183
+ permutation [i ] += diff
184
+ padded_permutation = list (range (diff )) + permutation
185
+ return padded_shape , padded_permutation
186
+
187
+ def axes_product (self , nhwc_shape : shape_t ) -> int :
188
+ product = 1
189
+ for axes in nhwc_shape :
190
+ product *= axes
191
+ return product
192
+
193
+ def _permute_constraint_i8_i16 (
194
+ self , nhwc_shape : list [int ], permutation : list [int ]
195
+ ) -> bool :
196
+ """Returns True if the constraints are ok."""
197
+ N , H , W , C = nhwc_shape
198
+ match permutation :
199
+ case (0 , 1 , 2 , 3 ): # NHWC -> NHWC
200
+ return True
201
+ case (0 , 2 , 1 , 3 ) | (0 , 1 , 3 , 2 ) | (0 , 3 , 1 , 2 ): # NHWC -> NWHC, NHCW, NCWH
202
+ return N * H <= 65536 and W <= 65536 and C <= 65536
203
+ case _:
204
+ return self .axes_product (nhwc_shape ) <= 65536
205
+
206
+ def _permute_constraint_i32 (
207
+ self , nhwc_shape : list [int ], permutation : list [int ]
208
+ ) -> bool :
209
+ """Returns True if the constraints are ok."""
210
+ N , H , W , C = nhwc_shape
211
+ match permutation :
212
+ case (0 , 1 , 2 , 3 ): # NHWC -> NHWC
213
+ return C <= 32768
214
+ case (0 , 2 , 1 , 3 ): # NHWC -> NHWC
215
+ return N == 1 and H <= 65536 and W <= 65536 and C <= 16384
216
+ case (0 , 1 , 3 , 2 ): # NHWC -> NHCW
217
+ return N * H <= 65536 and W <= 65536 and C <= 65536
218
+ case _:
219
+ return False
220
+
221
+ def _permute_constraint (self , shape , permutation , dtype ):
222
+ if dtype in (torch .int8 , torch .int16 ):
223
+ return self ._permute_constraint_i8_i16 (shape , permutation )
224
+ if dtype == torch .int32 :
225
+ return not self ._permute_constraint_i32 (shape , permutation )
226
+ return True
227
+
228
+ def is_node_supported (
229
+ self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
230
+ ) -> bool :
231
+
232
+ if not node .target == exir_ops .edge .aten .permute_copy .default :
233
+ return True
234
+
235
+ shape = list (get_first_fake_tensor (node ).shape )
236
+ dtype = _try_determine_dtype (node )
237
+ permutation = list (typing .cast (list [int ], node .args [1 ]))
238
+
239
+ rank = len (shape )
240
+ if rank > 4 :
241
+ if dtype == torch .int32 :
242
+ self .reporter .report_reject (
243
+ node , f"No support for { permutation = } in int32."
244
+ )
245
+ return False
246
+ if dtype in (torch .int8 , torch .int16 ):
247
+ if self .axes_product (shape ) > 65536 :
248
+ self .reporter .report_reject (
249
+ node ,
250
+ f"No support for { shape = } , { dtype = } . Product of axes must be <65536" ,
251
+ )
252
+ return False
253
+ return True
254
+
255
+ shape , permutation = self ._pad_to_rank_4 (shape , permutation )
256
+ if rank == 3 or rank == 4 :
257
+ # For rank 3 and 4, we can have channels first or channels last dim order.
258
+ # Since we don't know which at partition-time, test both.
259
+
260
+ nhwc_shape = tosa_shape (shape , [0 , 2 , 3 , 1 ])
261
+ nhwc_permutation = transform_permutation_vector (permutation , [0 , 2 , 3 , 1 ])
262
+
263
+ if not self ._permute_constraint (nhwc_shape , nhwc_permutation , dtype ):
264
+ self .reporter .report_reject (
265
+ node ,
266
+ f"Unsupported NHWC { nhwc_shape = } for { nhwc_permutation = } , { dtype = } " ,
267
+ )
268
+ return False
269
+
270
+ if not self ._permute_constraint (shape , permutation , dtype ):
271
+ self .reporter .report_reject (
272
+ node , f"Unsupported NCHW { shape = } for { permutation = } , { dtype = } "
273
+ )
274
+ return False
275
+
276
+ return True
0 commit comments