Skip to content

Commit 6013bef

Browse files
committed
add back del logits
1 parent fb11413 commit 6013bef

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

recipes/lora_finetune_single_device.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,9 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
543543
logits = logits.transpose(1, 2)
544544
# Compute loss
545545
loss = self._loss_fn(logits, labels)
546+
# free logits otherwise it peaks backward memory
547+
del logits
548+
546549
return loss
547550

548551
def train(self) -> None:

0 commit comments

Comments
 (0)