Commit f7115c2

Fix PyTorch stateful RNN/LSTM gradient computation error; resolves #20875 (#20916)
* Fix PyTorch stateful RNN gradient computation error
* Updates post feedback
1 parent 7a7bca6 commit f7115c2

File tree

1 file changed: +6 -0 lines changed

keras/src/layers/rnn/rnn.py

@@ -331,6 +331,12 @@ def inner_loop(self, sequences, initial_state, mask, training=False):
             cell_kwargs["training"] = training
 
         def step(inputs, states):
+            # Create new tensor copies when using PyTorch backend
+            # with stateful=True. This prevents in-place modifications
+            # that would otherwise break PyTorch's autograd functionality
+            # by modifying tensors needed for gradient computation.
+            if backend.backend() == "torch" and self.stateful:
+                states = tree.map_structure(ops.copy, states)
             output, new_states = self.cell(inputs, states, **cell_kwargs)
             if not tree.is_nested(new_states):
                 new_states = [new_states]
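
For readers unfamiliar with the failure mode this patch works around, below is a minimal, standalone sketch in plain PyTorch (not the Keras code path) showing why mutating a tensor that autograd has saved for the backward pass raises an error, and why operating on a copy, analogous to the tree.map_structure(ops.copy, states) call above, avoids it. The variable names and shapes are illustrative only.

import torch

# A weight that requires gradients and a "state" tensor that gets
# updated in place, mimicking a stateful RNN state buffer.
w = torch.ones(3, requires_grad=True)

# Safe pattern: compute on a copy of the state, so the later in-place
# update does not touch the tensor autograd saved for backward.
state = torch.zeros(3)
out = (w * state.clone()).sum()
state.add_(1.0)          # mutating the original is now harmless
out.backward()           # succeeds

# Failing pattern: autograd saves `state2` to compute d(out2)/dw, and the
# in-place update invalidates that saved tensor.
state2 = torch.zeros(3)
out2 = (w * state2).sum()
state2.add_(1.0)         # in-place modification of a saved tensor
try:
    out2.backward()
except RuntimeError as err:
    # "one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    print("autograd error:", err)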
