Skip to content

Commit 1facfa9

Browse files
Allow emitting mutable buffer names in schema
Differential Revision: D72579501 Pull Request resolved: #9935
1 parent 5c2b693 commit 1facfa9

File tree

5 files changed

+41
-1
lines changed

5 files changed

+41
-1
lines changed

exir/capture/_config.py

+5
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,8 @@ class ExecutorchBackendConfig:
9797
# If set to true, all trainable weights will be stored in a separate file,
9898
# external to the PTE file.
9999
external_mutable_weights: bool = False
100+
101+
# If set to true, all mutable buffers will have their fully qualified names
102+
# serialized in the PTE file. Its value is ignored if mutable buffers are not
103+
# memory planned as the names must be serialized in that case.
104+
emit_mutable_buffer_names: bool = False

exir/emit/_emit_program.py

+2
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def emit_program(
118118
methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
119119
emit_stacktrace: bool = False,
120120
prim_getters: Optional[Dict[str, Any]] = None,
121+
emit_mutable_buffer_names: bool = False,
121122
) -> EmitterOutput:
122123
"""
123124
Given a exported program, it returns the program in the format
@@ -163,6 +164,7 @@ def emit_program(
163164
operator_cache={},
164165
delegate_cache={},
165166
emit_stacktrace=emit_stacktrace,
167+
emit_mutable_buffer_names=emit_mutable_buffer_names,
166168
)
167169

168170
gm = _remove_non_user_outputs(exported_program)

exir/emit/_emitter.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ class _EmitterState:
149149
# delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
150150
delegate_cache: Dict[str, int]
151151
emit_stacktrace: bool
152+
emit_mutable_buffer_names: bool
152153

153154
spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)
154155

@@ -1610,7 +1611,7 @@ def _find_fqn_for_placeholder(
16101611
)
16111612
return fqn, is_mutable_buffer
16121613

1613-
def placeholder(
1614+
def placeholder( # noqa: C901
16141615
self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
16151616
) -> _AbstractValue:
16161617
"""Emits the value within the placeholder node.
@@ -1639,6 +1640,13 @@ def placeholder(
16391640
else:
16401641
spec.extra_tensor_info.fully_qualified_name = fqn
16411642
spec.extra_tensor_info.location = TensorDataLocation.EXTERNAL
1643+
if self.emitter_state.emit_mutable_buffer_names and is_mutable_buffer:
1644+
if spec.extra_tensor_info is None:
1645+
spec.extra_tensor_info = ExtraTensorInfo(
1646+
fully_qualified_name=fqn, location=TensorDataLocation.SEGMENT
1647+
)
1648+
else:
1649+
spec.extra_tensor_info.fully_qualified_name = fqn
16421650

16431651
# From the fqn find the corresponding tensor
16441652
real_tensor = None

exir/emit/test/test_emit.py

+24
Original file line numberDiff line numberDiff line change
@@ -1819,3 +1819,27 @@ def forward(self, input, label):
18191819
]
18201820
self.assertEqual(external_map["net.linear.weight"], 0)
18211821
self.assertEqual(external_map["net.linear.bias"], 1)
1822+
1823+
def test_emit_mutable_buffer_names(self) -> None:
1824+
class Net(nn.Module):
1825+
def __init__(self):
1826+
super().__init__()
1827+
self.linear = nn.Linear(2, 2)
1828+
self.register_buffer("buffer", torch.zeros(1, 2))
1829+
1830+
def forward(self, x):
1831+
self.buffer.add_(1)
1832+
return self.linear(x) + self.buffer
1833+
1834+
net = Net()
1835+
1836+
ep = export(net, (torch.randn(1, 2),), strict=True)
1837+
# Lower the graph to edge dialect.
1838+
ep = to_edge(ep)
1839+
# Lower the graph to executorch.
1840+
ep = ep.to_executorch(
1841+
config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)
1842+
)
1843+
for val in ep.executorch_program.execution_plan[0].values:
1844+
if isinstance(val, Tensor) and val.extra_tensor_info:
1845+
self.assertEqual(val.extra_tensor_info.fully_qualified_name, "buffer")

exir/program/_program.py

+1
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,7 @@ def __init__(
16121612
self._execution_programs,
16131613
backend_config.emit_stacktrace,
16141614
self._config_methods,
1615+
backend_config.emit_mutable_buffer_names,
16151616
)
16161617

16171618
# Serialize emitter output, ready to be written to a file.

0 commit comments

Comments
 (0)