
Commit 6e600d0

cptspacemanspiff authored and DannyYuyang-quic committed
Changes to allow the export of functions with no user input. (pytorch#8031)
### Summary

This is a set of changes that I have been using to get executorch working with multiple entrypoints with shared state. This is related to: pytorch#8030

### Issue & Changes

Currently, trying to export a function:

```python
def init():
    self.shared_state.fill_(0)
```

will fail, because `to_edge` (in exir's `program.py`) assumes the lifted constants are placed before the first user input. This causes weird behavior when there are no user inputs. I changed it to place the lifted constants after the last buffer if there are no user inputs to place them before. However, this relies on my understanding of the implicit layout of the graph inputs.

Later, after the memory-planning phase of `to_executorch`, validation checks that memory planning was correct based on whether the `graph_input_allocated` flag is set. That flag only applies to user inputs, of which we have none, so the check errors out. I added a check to bypass this error if there are no user inputs, but I honestly do not understand enough of the validation check to know if that is appropriate.

### Comments

In the current executorch, with no support for shared state, this case does not make sense, but pytorch#8030 is my attempt at adding that capability. Having initialization methods that set up the buffers from constants is useful, especially since their initial state is undefined.

Currently this is not ready: there are no tests, my random notes are left in the commit as comments, and beyond confirming that it worked while I was developing the shared-state export, I have not validated it in depth. But I would like feedback on whether my solution seems correct, or whether I am missing something, particularly regarding my understanding of placeholder ordering and signature logic, and whether bypassing the `graph_input_allocated` validation is appropriate. Out of all the changes I had to make for shared state, this is the one I am least sure about.

cc @JacobSzwejbka @angelayi
1 parent b142651 commit 6e600d0
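To make the failing scenario concrete, here is a minimal sketch of an init-style entrypoint with no user inputs, in the spirit of the commit message and the test added below. The module and buffer names (`InitOnly`, `shared_state`) are illustrative, and it assumes recent `torch.export` and executorch `to_edge`/`to_executorch` APIs:

```python
import torch
from torch.export import export
from executorch.exir import to_edge


class InitOnly(torch.nn.Module):
    """An init-style entrypoint that takes no user inputs."""

    def __init__(self):
        super().__init__()
        self.register_buffer("shared_state", torch.zeros(4))

    def forward(self) -> torch.Tensor:
        # Reset the shared buffer in place, then return a value derived
        # from it. The scalar 3 is lifted to a tensor constant during
        # export, which is what trips the placement logic in to_edge().
        self.shared_state.fill_(0)
        return self.shared_state + 3


# export() accepts an empty args tuple for a forward() with no inputs.
ep = export(InitOnly(), (), strict=True)
# Before this commit this failed twice over: to_edge() assumed the lifted
# constant could be placed before the first user input, and
# to_executorch()'s memory-planning validation required a user input
# to have been allocated.
prog = to_edge(ep).to_executorch()
```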

File tree

3 files changed: +46 −2 lines changed


exir/memory_planning.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -24,7 +24,7 @@
 from executorch.exir.tensor import TensorSpec

 from torch import fx
-from torch.export.exported_program import ExportGraphSignature
+from torch.export.exported_program import ExportGraphSignature, InputKind
 from torch.fx import Node
 from torch.utils._pytree import tree_flatten

@@ -247,7 +247,19 @@ def verify_graph_input_output(self) -> None:
             graph_output_allocated = allocated
             has_dynamic_unbound_output |= has_dynamic_unbound_tensor

-        if "placeholder" in check_list:
+        # only check if inputs are allocated if there are user inputs:
+        user_inputs_exist = (
+            len(
+                list(
+                    filter(
+                        lambda input: input.kind == InputKind.USER_INPUT,
+                        self.graph_signature.input_specs,
+                    )
+                )
+            )
+            > 0
+        )
+
+        if "placeholder" in check_list and user_inputs_exist:
             assert graph_input_allocated is not None, "graph_input_allocated not set"
             if not has_dynamic_unbound_input:
                 assert (
```
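As an aside, the added predicate materializes the whole filtered list via `filter`/`list`/`len`; an equivalent and slightly more direct formulation (a sketch, not what the commit uses) would short-circuit with `any()`:

```python
from torch.export.exported_program import ExportGraphSignature, InputKind


def has_user_inputs(signature: ExportGraphSignature) -> bool:
    # True when at least one input spec is a user input, as opposed to a
    # parameter, buffer, or lifted constant; any() stops at the first
    # match instead of building the filtered list just to measure it.
    return any(
        spec.kind == InputKind.USER_INPUT for spec in signature.input_specs
    )
```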

exir/program/_program.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -322,6 +322,8 @@ def lift_constant_tensor_pass(ep):
             new_input_specs.extend(lifted_constants)
             lifted_constants.clear()
         new_input_specs.append(s)
+    if len(lifted_constants) > 0:
+        new_input_specs = lifted_constants + new_input_specs
     ep.graph_signature.input_specs = new_input_specs
     ep.graph_module.recompile()
     return ep
```
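For context: the loop above this hunk (not fully shown) flushes `lifted_constants` into `new_input_specs` just before the first user-input spec, so with no user inputs that flush never fires and the constants would be dropped from the signature. A simplified model of the resulting placement logic, with hypothetical spec objects that only need a `.kind` attribute:

```python
from typing import List

from torch.export.exported_program import InputKind


def place_lifted_constants(input_specs: List, lifted_constants: List) -> List:
    # Simplified model of the placement in lift_constant_tensor_pass:
    # lifted constants are spliced in just before the first user input.
    pending = list(lifted_constants)
    new_input_specs = []
    for s in input_specs:
        if s.kind == InputKind.USER_INPUT and pending:
            new_input_specs.extend(pending)
            pending = []
        new_input_specs.append(s)
    # The fallback added by this commit: with no user inputs the flush
    # above never fires, so prepend whatever is still pending.
    if pending:
        new_input_specs = pending + new_input_specs
    return new_input_specs
```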

exir/tests/test_passes.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -1057,6 +1057,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             new_ep.graph_module.code
         )

+    def test_pass_no_user_inputs(self) -> None:
+        class NoUserInputs(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("a", torch.ones(1))
+
+            def forward(self) -> torch.Tensor:
+                return 3 + self.a
+
+        mod = NoUserInputs()
+        exported_program = export(mod, (), strict=True)
+        edge = to_edge(
+            exported_program,
+            compile_config=EdgeCompileConfig(_skip_dim_order=False),
+        )
+        ep = edge.exported_program()
+        # because there is no user input, the lifted constant should be the first input.
+        FileCheck().check("_lifted_tensor_constant1").check(
+            "b_a"  # followed by the buffer input.
+        ).run(ep.graph_module.code)
+
+        # the graph signature should also be the same:
+        self.assertEqual(
+            ep.graph_signature.input_specs[0].arg.name, "_lifted_tensor_constant1"
+        )
+        self.assertEqual(ep.graph_signature.input_specs[1].arg.name, "b_a")
+
+        # Validate that the program successfully passes validation to executorch:
+        edge.to_executorch()
+
     def test_constant_prop_pass_for_parameter(self) -> None:
         def count_additions(gm: torch.fx.GraphModule) -> int:
             return sum(
```
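Assuming a standard executorch development checkout, the new test can be run in isolation with `python -m pytest exir/tests/test_passes.py -k test_pass_no_user_inputs` (an illustrative invocation; the repo's own test runner may differ).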
