Commit 73e0a57

carmocca authored and rohitgr7 committed
Remove manual tracking of optimizer steps (Lightning-AI#9957)
1 parent 4dcc078 commit 73e0a57

4 files changed, +22 −20 lines changed

pytorch_lightning/core/optimizer.py

Lines changed: 0 additions & 2 deletions
@@ -46,7 +46,6 @@ def __init__(self, optimizer: Optimizer):
         self._optimizer = optimizer
         self._trainer = None
         self._optimizer_idx = None
-        self._total_optimizer_step_calls = 0
 
     @property
     def optimizer(self):
@@ -192,7 +191,6 @@ def closure_dis():
         trainer = self._trainer
         with trainer.profiler.profile(profiler_action):
             trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
-            self._total_optimizer_step_calls += 1
 
     def __repr__(self):
         groups = [
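
The `_total_optimizer_step_calls` counter removed here existed only so tests could count optimizer steps. The replacement pattern, used by the updated tests below, wraps `Optimizer.step` with `unittest.mock` so the real step still runs while its calls are recorded. A minimal standalone sketch (the linear model and SGD optimizer are illustrative, not taken from the diff):

from unittest.mock import patch

import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# `wraps=` forwards every call to the real `step`, so optimization behaves
# exactly as before while the mock records how often it was invoked.
with patch.object(optimizer, "step", wraps=optimizer.step) as step_mock:
    for _ in range(3):
        optimizer.zero_grad()
        model(torch.randn(4, 2)).sum().backward()
        optimizer.step()

assert step_mock.call_count == 3

The tests below use the `start()`/`stop()` form of the same patcher rather than a `with` block, because the patch has to stay active from `on_train_start` until `on_train_end`.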

tests/accelerators/test_tpu.py

Lines changed: 8 additions & 7 deletions
@@ -13,6 +13,7 @@
 # limitations under the License
 import collections
 from copy import deepcopy
+from unittest.mock import patch
 
 import pytest
 import torch
@@ -21,7 +22,6 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
-from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.utilities import find_shared_parameters
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -189,16 +189,18 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
             assert torch.all(self.layer.weight.grad == 0)
             self.count += 1
 
+        def on_train_start(self):
+            opt = self.optimizers()
+            self.opt_step_patch = patch.object(opt, "step", wraps=opt.step)
+            self.opt_step_mock = self.opt_step_patch.start()
+
         def on_train_end(self):
             assert self.called["training_step"] == 5
             assert self.called["on_train_batch_start"] == 5
             assert self.called["on_train_batch_end"] == 5
 
-    class TestManualOptimizationCallack(Callback):
-        def on_train_end(self, trainer, pl_module):
-
-            opt = pl_module.optimizers()
-            assert opt._total_optimizer_step_calls == 3
+            self.opt_step_patch.stop()
+            assert self.opt_step_mock.call_count == 3
 
     model = ManualOptimizationModel()
     model_copy = deepcopy(model)
@@ -212,7 +214,6 @@ def on_train_end(self, trainer, pl_module):
         limit_test_batches=0,
         limit_val_batches=0,
         tpu_cores=8,
-        callbacks=[TestManualOptimizationCallack()],
     )
     trainer.fit(model)

tests/core/test_lightning_optimizer.py

Lines changed: 0 additions & 1 deletion
@@ -161,7 +161,6 @@ def test_state(tmpdir):
         "zero_grad",
         "__setstate__",
         "add_param_group",
-        "_total_optimizer_step_calls",
     ]
 
     for k, v in lightning_optimizer.__dict__.items():

tests/trainer/optimization/test_manual_optimization.py

Lines changed: 14 additions & 10 deletions
@@ -23,7 +23,6 @@
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.callbacks import Callback
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -706,14 +705,6 @@ def configure_optimizers(self):
     mock_adam_step.assert_has_calls(expected_calls)
 
 
-class TestManualOptimizationDDPCallack(Callback):
-    def on_train_end(self, trainer, pl_module):
-
-        opt_a, opt_b = pl_module.optimizers()
-        assert opt_a._total_optimizer_step_calls == 4
-        assert opt_b._total_optimizer_step_calls == 2
-
-
 class TesManualOptimizationDDPModel(BoringModel):
     def __init__(self):
         super().__init__()
@@ -787,6 +778,20 @@ def configure_optimizers(self):
         optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001)
         return [optimizer_gen, optimizer_dis]
 
+    def on_train_start(self):
+        # this is done here instead of in the calling function due to `spawn`
+        sgd, adam = self.optimizers()
+        self.sgd_step_patch = patch.object(sgd, "step", wraps=sgd.step)
+        self.sgd_step_mock = self.sgd_step_patch.start()
+        self.adam_step_patch = patch.object(adam, "step", wraps=adam.step)
+        self.adam_step_mock = self.adam_step_patch.start()
+
+    def on_train_end(self):
+        self.sgd_step_patch.stop()
+        assert self.sgd_step_mock.call_count == 4
+        self.adam_step_patch.stop()
+        assert self.adam_step_mock.call_count == 2
+
 
 def train_manual_optimization(tmpdir, strategy, model_cls=TesManualOptimizationDDPModel):
 
@@ -806,7 +811,6 @@ def train_manual_optimization(tmpdir, strategy, model_cls=TesManualOptimizationD
         log_every_n_steps=1,
         gpus=2,
         strategy=strategy,
-        callbacks=[TestManualOptimizationDDPCallack()],
     )
 
     trainer.fit(model)
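
The "done here instead of in the calling function due to `spawn`" comment above is the subtle part of this change: with a spawn-based strategy the model and its optimizers are pickled into fresh worker processes, so a patch started in the parent test function would never be visible where training actually runs. A small plain-multiprocessing sketch of that pitfall (the names here are illustrative, not part of the diff):

from multiprocessing import get_context
from unittest.mock import patch

import torch


def check_patch(optimizer):
    # The child rebuilds the optimizer from its pickled state, which does not
    # carry the parent's mock, so `step` is the real bound method again.
    print("child sees mock:", hasattr(optimizer.step, "call_count"))  # False


if __name__ == "__main__":
    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    with patch.object(optimizer, "step", wraps=optimizer.step):
        print("parent sees mock:", hasattr(optimizer.step, "call_count"))  # True
        proc = get_context("spawn").Process(target=check_patch, args=(optimizer,))
        proc.start()
        proc.join()

Starting the patch in `on_train_start` sidesteps this, because the hook executes inside each worker after the module has been transferred.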
