@@ -6,7 +6,7 @@

 import logging
 from copy import deepcopy
-from typing import Set
+from typing import Any, Set

 import executorch.backends.vulkan.utils as utils

@@ -190,20 +190,24 @@ def propose_node_layout(
         return next(iter(valid_layouts))

     def should_annotate(self, node) -> bool:
-        if not isinstance(node, torch.fx.Node):
-            return False
-
-        if not utils.is_tensor_node(node):
-            return False
-
-        # Storage type and memory layout for tensorref will be determined at runtime
-        # so there's no use in setting those attributes ahead of time.
-        if node.meta.get("vkdg_tensorref", False):
-            return False
-
-        # Skip annotating output node. The output tensors should be annotated by the
-        # time the output node is observed.
-        if node.op == "output":
+        if isinstance(node, torch.fx.Node):
+            if not utils.is_tensor_node(node):
+                return False
+
+            # Storage type and memory layout for tensorref will be determined at runtime
+            # so there's no use in setting those attributes ahead of time.
+            if node.meta.get("vkdg_tensorref", False):
+                return False
+
+            # Skip annotating output node. The output tensors should be annotated by the
+            # time the output node is observed.
+            if node.op == "output":
+                return False
+        elif isinstance(node, (list, tuple)):
+            return all(
+                isinstance(n, torch.fx.Node) and self.should_annotate(n) for n in node
+            )
+        else:
             return False

         return True
@@ -215,6 +219,70 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:
         # time the prepack node is observed.
         return node.target == exir_ops.edge.et_vk.prepack.default

+    def set_or_transition_arg_node(
+        self,
+        i: int,
+        arg: torch.fx.Node,
+        node: torch.fx.Node,
+        graph_module: torch.fx.GraphModule,
+        dirty: bool,
+    ) -> bool:
+        assert isinstance(arg, torch.fx.Node)
+
+        storage = utils.get_node_storage_type(node)
+        assert storage is not None
+        layout = utils.get_node_memory_layout(node)
+        assert layout is not None
+
+        arg_storage = utils.get_node_storage_type(arg)
+        arg_layout = utils.get_node_memory_layout(arg)
+
+        if arg_storage is None:
+            utils.set_node_spec_attr(arg, "vk_storage_type", storage)
+            arg_storage = storage
+        if arg_layout is None:
+            utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
+            arg_layout = layout
+
+        if arg_storage == storage and arg_layout == layout:
+            return False
+
+        if not dirty:
+            logger.info(
+                f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
+            )
+
+        insert_transition_node(graph_module, node, arg, storage, layout)
+
+        logger.info(
+            f"  args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+        )
+
+        return True
+
+    def set_or_transition_arg(
+        self,
+        i: int,
+        arg: Any,
+        node: torch.fx.Node,
+        graph_module: torch.fx.GraphModule,
+        dirty: bool,
+    ) -> bool:
+        if isinstance(arg, torch.fx.Node):
+            return self.set_or_transition_arg_node(i, arg, node, graph_module, dirty)
+        elif isinstance(arg, (list, tuple)):
+            need_transition = False
+            for arg_node in arg:
+                need_transition = (
+                    self.set_or_transition_arg_node(
+                        i, arg_node, node, graph_module, need_transition
+                    )
+                    or need_transition
+                )
+            return need_transition
+        else:
+            return False
+
     # noqa
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.nodes:
@@ -226,36 +294,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

             set_memory_metadata(node, storage, layout)

-            inserting_transitions_for_node = False
+            need_transition = False
             for i, arg in enumerate(node.args):
                 if not self.should_annotate(arg):
                     continue

-                assert isinstance(arg, torch.fx.Node)
-
-                arg_storage = utils.get_node_storage_type(arg)
-                arg_layout = utils.get_node_memory_layout(arg)
-
-                if arg_storage is None:
-                    utils.set_node_spec_attr(arg, "vk_storage_type", storage)
-                    arg_storage = storage
-                if arg_layout is None:
-                    utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
-                    arg_layout = layout
-
-                if arg_storage == storage and arg_layout == layout:
-                    continue
-
-                if not inserting_transitions_for_node:
-                    inserting_transitions_for_node = True
-                    logger.info(
-                        f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
+                need_transition = (
+                    self.set_or_transition_arg(
+                        i, arg, node, graph_module, need_transition
                     )
-
-                insert_transition_node(graph_module, node, arg, storage, layout)
-
-                logger.info(
-                    f"  args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+                    or need_transition
                 )

         return PassResult(graph_module, True)
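
A note on the accumulator pattern above: both set_or_transition_arg and the rewritten call loop fold results with `flag = step(...) or flag` rather than `flag = flag or step(...)`. The order matters because `or` short-circuits: placing the call first guarantees every argument is processed, while the accumulated flag (dirty / need_transition) lets the first inserted transition emit the log header exactly once. A minimal standalone sketch follows; the process helper and sample data are hypothetical, not part of this commit:

    # Standalone sketch of the `flag = step(...) or flag` fold used above.
    def process(i: int, arg: str, dirty: bool) -> bool:
        """Pretend transition handler: only even-indexed args need a transition."""
        needs_transition = i % 2 == 0
        if needs_transition and not dirty:
            print("[Vulkan Delegate] Inserting transition(s) ...")  # header printed once
        if needs_transition:
            print(f"  args {i} ({arg})")
        return needs_transition

    need_transition = False
    for i, arg in enumerate(["a", "b", "c"]):
        # The call comes before `or need_transition` so it always executes;
        # `need_transition = need_transition or process(...)` would skip
        # process() for every remaining arg after the first transition.
        need_transition = process(i, arg, need_transition) or need_transition

    print(need_transition)  # -> True: at least one arg required a transition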