Commit 791fe7f

pytorchbot and SS-JIA authored
[ET-VK] Allow memory tagging pass to handle nodes with list of tensor args (#9203)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #9173 by @SS-JIA
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/196/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/196/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/196/orig
@diff-train-skip-merge

Co-authored-by: Stephen Jia <[email protected]>
1 parent 4de19fd commit 791fe7f
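
For context on the title: operators such as cat take a list of tensors, so in the exported FX graph the corresponding node's first argument is a Python list of torch.fx.Node objects rather than a single Node. The old should_annotate rejected anything that was not itself a Node, so such arguments were skipped; the updated pass (diff below) recurses into list/tuple arguments and handles each element. The sketch below is illustrative only, uses plain torch.fx.symbolic_trace rather than the ExecuTorch export pipeline, and the module name CatModule is made up; it just shows what a list-of-Nodes argument looks like.

import torch
import torch.fx


class CatModule(torch.nn.Module):
    def forward(self, x, y):
        # cat receives a Python list of tensors
        return torch.cat([x, y], dim=0)


gm = torch.fx.symbolic_trace(CatModule())
for node in gm.graph.nodes:
    if node.op == "call_function":
        # For the cat node, node.args[0] is a list of torch.fx.Node objects,
        # not a single Node: the case the updated pass can now walk.
        print(node.target, type(node.args[0]), node.args[0])

With the new should_annotate, such a list argument is annotatable only if every element is itself an annotatable tensor node, and set_or_transition_arg then sets or transitions each element individually.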

File tree

1 file changed (+88, -40 lines)


backends/vulkan/_passes/tag_memory_meta_pass.py

+88-40
@@ -6,7 +6,7 @@
 
 import logging
 from copy import deepcopy
-from typing import Set
+from typing import Any, Set
 
 import executorch.backends.vulkan.utils as utils
 
@@ -190,20 +190,24 @@ def propose_node_layout(
         return next(iter(valid_layouts))
 
     def should_annotate(self, node) -> bool:
-        if not isinstance(node, torch.fx.Node):
-            return False
-
-        if not utils.is_tensor_node(node):
-            return False
-
-        # Storage type and memory layout for tensorref will be determined at runtime
-        # so there's no use in setting those attributes ahead of time.
-        if node.meta.get("vkdg_tensorref", False):
-            return False
-
-        # Skip annotating output node. The output tensors should be annotated by the
-        # time the output node is observed.
-        if node.op == "output":
+        if isinstance(node, torch.fx.Node):
+            if not utils.is_tensor_node(node):
+                return False
+
+            # Storage type and memory layout for tensorref will be determined at runtime
+            # so there's no use in setting those attributes ahead of time.
+            if node.meta.get("vkdg_tensorref", False):
+                return False
+
+            # Skip annotating output node. The output tensors should be annotated by the
+            # time the output node is observed.
+            if node.op == "output":
+                return False
+        elif isinstance(node, (list, tuple)):
+            return all(
+                isinstance(n, torch.fx.Node) and self.should_annotate(n) for n in node
+            )
+        else:
             return False
 
         return True
@@ -215,6 +219,70 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:
         # time the prepack node is observed.
         return node.target == exir_ops.edge.et_vk.prepack.default
 
+    def set_or_transition_arg_node(
+        self,
+        i: int,
+        arg: torch.fx.Node,
+        node: torch.fx.Node,
+        graph_module: torch.fx.GraphModule,
+        dirty: bool,
+    ) -> bool:
+        assert isinstance(arg, torch.fx.Node)
+
+        storage = utils.get_node_storage_type(node)
+        assert storage is not None
+        layout = utils.get_node_memory_layout(node)
+        assert layout is not None
+
+        arg_storage = utils.get_node_storage_type(arg)
+        arg_layout = utils.get_node_memory_layout(arg)
+
+        if arg_storage is None:
+            utils.set_node_spec_attr(arg, "vk_storage_type", storage)
+            arg_storage = storage
+        if arg_layout is None:
+            utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
+            arg_layout = layout
+
+        if arg_storage == storage and arg_layout == layout:
+            return False
+
+        if not dirty:
+            logger.info(
+                f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
+            )
+
+        insert_transition_node(graph_module, node, arg, storage, layout)
+
+        logger.info(
+            f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+        )
+
+        return True
+
+    def set_or_transition_arg(
+        self,
+        i: int,
+        arg: Any,
+        node: torch.fx.Node,
+        graph_module: torch.fx.GraphModule,
+        dirty: bool,
+    ) -> bool:
+        if isinstance(arg, torch.fx.Node):
+            return self.set_or_transition_arg_node(i, arg, node, graph_module, dirty)
+        elif isinstance(arg, (list, tuple)):
+            need_transition = False
+            for arg_node in arg:
+                need_transition = (
+                    self.set_or_transition_arg_node(
+                        i, arg_node, node, graph_module, need_transition
+                    )
+                    or need_transition
+                )
+            return need_transition
+        else:
+            return False
+
     # noqa
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.nodes:
@@ -226,36 +294,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
             set_memory_metadata(node, storage, layout)
 
-            inserting_transitions_for_node = False
+            need_transition = False
             for i, arg in enumerate(node.args):
                 if not self.should_annotate(arg):
                     continue
 
-                assert isinstance(arg, torch.fx.Node)
-
-                arg_storage = utils.get_node_storage_type(arg)
-                arg_layout = utils.get_node_memory_layout(arg)
-
-                if arg_storage is None:
-                    utils.set_node_spec_attr(arg, "vk_storage_type", storage)
-                    arg_storage = storage
-                if arg_layout is None:
-                    utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
-                    arg_layout = layout
-
-                if arg_storage == storage and arg_layout == layout:
-                    continue
-
-                if not inserting_transitions_for_node:
-                    inserting_transitions_for_node = True
-                    logger.info(
-                        f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
+                need_transition = (
+                    self.set_or_transition_arg(
+                        i, arg, node, graph_module, need_transition
                     )
-
-                insert_transition_node(graph_module, node, arg, storage, layout)
-
-                logger.info(
-                    f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+                    or need_transition
                 )
 
         return PassResult(graph_module, True)
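
A note on the new flag handling: both set_or_transition_arg and the rewritten call loop accumulate the result as helper(...) or need_transition rather than need_transition or helper(...), so short-circuit evaluation never skips the helper call that may set metadata or insert a transition node; the flag only controls whether the "[Vulkan Delegate] Inserting transition(s)" header is logged once per node. A minimal, self-contained sketch (the names visit and items are made up) of why the operand order matters:

def visit(x: int) -> bool:
    # Stand-in for set_or_transition_arg_node: always has a side effect,
    # returns True only when something was changed.
    print("visiting", x)
    return x % 2 == 0


items = [1, 2, 3, 4]

changed = False
for x in items:
    # visit() runs for every element; `changed` just records whether any call
    # returned True. Writing `changed = changed or visit(x)` would stop calling
    # visit() after the first True because of short-circuiting.
    changed = visit(x) or changed

print("changed:", changed)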

0 commit comments
