Skip to content

Commit 566484b

Browse files
cccclai authored and facebook-github-bot committed
disable disables gradient calculation when getting the prop constant tensor (#3948)
Summary: Pull Request resolved: #3948 As title, to fix the error when exporting/lowering lstm, the error message is: ``` Cell In[13], line 3 1 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner ----> 3 edge_manager = edge_manager.to_backend(XnnpackPartitioner()) 5 print(edge_manager.exported_program()) File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/executorch/exir/program/_program.py:1166, in EdgeProgramManager.to_backend(self, partitioner) 1164 else: # apply partitioner to every method 1165 for name, program in self._edge_programs.items(): -> 1166 new_edge_programs[name] = to_backend(program, partitioner) 1168 config = EdgeCompileConfig(_check_ir_validity=False) 1169 return EdgeProgramManager( 1170 new_edge_programs, copy.deepcopy(self._config_methods), config 1171 ) File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/runtime/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw) 885 if not args: 886 raise TypeError(f'{funcname} requires at least ' 887 '1 positional argument') --> 889 return dispatch(args[0].__class__)(*args, **kw) File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/executorch/exir/backend/backend_api.py:384, in _(edge_program, partitioner_instance) 381 for tag, _ in partitioner_result.partition_tags.items(): 382 _maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program) --> 384 tagged_graph_module = _partition_and_lower( 385 tagged_exported_program.graph_module, partitioner_result, edge_program 386 ) 388 # TODO(angelayi): Update this signature in a less manual way (maybe through 389 # retracing) 390 new_signature, new_state_dict, new_constants = _get_new_signature( 391 edge_program, 392 tagged_graph_module, 393 ) File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/executorch/exir/backend/backend_api.py:299, in 
_partition_and_lower(tagged_graph_module, partition_result, owning_program) 290 def _partition_and_lower( 291 tagged_graph_module: torch.fx.GraphModule, 292 partition_result: PartitionResult, 293 owning_program: ExportedProgram, 294 ) -> torch.fx.GraphModule: 295 """ 296 Partitions the graph module into submodules based on tags, and then lowered the nodes with the same tag as one lowered module, including the submodule from control flow 297 """ --> 299 partitioned_module = _partition_and_lower_one_graph_module( 300 tagged_graph_module, partition_result, owning_program 301 ) 303 # Recursively partition and lower for submodules 304 for name, submod, _node in get_control_flow_submodules(partitioned_module): File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/executorch/exir/backend/backend_api.py:230, in _partition_and_lower_one_graph_module(tagged_graph_module, partition_result, owning_program) 224 logging.debug(f"Partitioned graph module: {tagged_graph_module}") 226 submodule_program = create_exported_program_from_submodule( 227 submodule, owning_program, tag 228 ) --> 230 lowered_submodule = to_backend( 231 delegation_spec.backend_id, 232 submodule_program, 233 delegation_spec.compile_specs, 234 ) 236 # call delegate args should only use user_inputs 237 call_delegate_args = [] File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/runtime/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw) 885 if not args: 886 raise TypeError(f'{funcname} requires at least ' 887 '1 positional argument') --> 889 return dispatch(args[0].__class__)(*args, **kw) File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/executorch/exir/backend/backend_api.py:113, in _(backend_id, edge_program, compile_specs) 111 for cls in BackendDetails.__subclasses__(): 112 if backend_id == cls.__name__: --> 113 copied_edge_program = copy.deepcopy(edge_program) 114 
preprocess_result: PreprocessResult = cls.preprocess( 115 copied_edge_program, 116 compile_specs, 117 ) 118 lowered_module = LoweredBackendModule( 119 edge_program=edge_program, 120 backend_id=backend_id, 121 processed_bytes=preprocess_result.processed_bytes, 122 compile_specs=compile_specs, 123 ) File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/runtime/lib/python3.10/copy.py:172, in deepcopy(x, memo, _nil) 170 y = x 171 else: --> 172 y = _reconstruct(x, memo, *rv) 174 # If is its own copy, don't memoize. 175 if y is not x: File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/runtime/lib/python3.10/copy.py:271, in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy) 269 if state is not None: 270 if deep: --> 271 state = deepcopy(state, memo) 272 if hasattr(y, '__setstate__'): 273 y.__setstate__(state) File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/runtime/lib/python3.10/copy.py:146, in deepcopy(x, memo, _nil) 144 copier = _deepcopy_dispatch.get(cls) 145 if copier is not None: --> 146 y = copier(x, memo) 147 else: 148 if issubclass(cls, type): File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/runtime/lib/python3.10/copy.py:231, in _deepcopy_dict(x, memo, deepcopy) 229 memo[id(x)] = y 230 for key, value in x.items(): --> 231 y[deepcopy(key, memo)] = deepcopy(value, memo) 232 return y File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/runtime/lib/python3.10/copy.py:146, in deepcopy(x, memo, _nil) 144 copier = _deepcopy_dispatch.get(cls) 145 if copier is not None: --> 146 y = copier(x, memo) 147 else: 148 if issubclass(cls, type): File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/runtime/lib/python3.10/copy.py:231, in _deepcopy_dict(x, memo, deepcopy) 229 memo[id(x)] = y 230 for key, value in x.items(): --> 231 y[deepcopy(key, memo)] = 
deepcopy(value, memo) 232 return y File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/runtime/lib/python3.10/copy.py:153, in deepcopy(x, memo, _nil) 151 copier = getattr(x, "__deepcopy__", None) 152 if copier is not None: --> 153 y = copier(memo) 154 else: 155 reductor = dispatch_table.get(cls) File /mnt/xarfuse/uid-27416/e8d86d8d-seed-nspid4026533405_cgpid10356714-ns-4026533402/torch/_tensor.py:86, in Tensor.__deepcopy__(self, memo) 84 return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo) 85 if not self.is_leaf: ---> 86 raise RuntimeError( 87 "Only Tensors created explicitly by the user " 88 "(graph leaves) support the deepcopy protocol at the moment. " 89 "If you were attempting to deepcopy a module, this may be because " 90 "of a torch.nn.utils.weight_norm usage, " 91 "see pytorch/pytorch#103001" 92 ) 93 if id(self) in memo: 94 return memo[id(self)] RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see pytorch/pytorch#103001 ``` The reason is that the constant tensor has a grad_fn when it is produced without the no-grad context manager. Reviewed By: angelayi Differential Revision: D58436236
1 parent 1345bc2 commit 566484b

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

exir/passes/constant_prop_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,11 @@ def get_propagated_const_tensor_dict(
133133
lambda x: get_data(x, exported_program, const_node_to_tensor),
134134
(node.args, node.kwargs),
135135
)
136-
137-
# Execute the `node.target` and create a new propagated constant tensor.
138-
prop_constant_tensor = node.target(*args_data, **kwargs_data)
136+
# Disable grad for constant propagation, otherwise the generated tensor can't be copied
137+
# because of the grad_fn.
138+
with torch.no_grad():
139+
# Execute the `node.target` and create a new propagated constant tensor.
140+
prop_constant_tensor = node.target(*args_data, **kwargs_data)
139141
const_node_to_tensor[node] = prop_constant_tensor
140142

141143
return const_node_to_tensor

exir/tests/test_passes.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8+
import copy
89
import os
910
import tempfile
1011
import unittest
@@ -1639,3 +1640,36 @@ def forward(self, x):
16391640
FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run(
16401641
gm.code
16411642
)
1643+
1644+
def test_constant_prop_pass_for_no_grad(self) -> None:
1645+
class LSTM(torch.nn.Module):
1646+
def __init__(self, input_size, hidden_size, num_layers):
1647+
super(LSTM, self).__init__()
1648+
self.hidden_size = hidden_size
1649+
self.num_layers = num_layers
1650+
self.lstm = torch.nn.LSTM(
1651+
input_size, hidden_size, num_layers, batch_first=True
1652+
)
1653+
1654+
def forward(self, text_tokens):
1655+
# input: (seq_len, batch, input_size)
1656+
lstm_out, (new_hidden_state, new_cell_state) = self.lstm(
1657+
input=text_tokens, hx=None
1658+
)
1659+
return lstm_out
1660+
1661+
lstm = LSTM(input_size=200, hidden_size=203, num_layers=2)
1662+
example_input = (torch.rand(2, 10, 200),)
1663+
1664+
aten = torch.export.export(lstm, example_input, strict=False)
1665+
_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
1666+
_check_ir_validity=True,
1667+
_skip_dim_order=True, # TODO(T189114319): Reuse dim order op after solving the ios oss issue
1668+
)
1669+
1670+
edge_manager: EdgeProgramManager = to_edge(
1671+
aten,
1672+
compile_config=_EDGE_COMPILE_CONFIG,
1673+
)
1674+
new_ep = constant_prop_pass(edge_manager._edge_programs["forward"])
1675+
_ = copy.deepcopy(new_ep.module_call_graph)

0 commit comments

Comments (0)