
Commit 788f686

Fix LightningOptimizer step and toggling logic (#9958)
1 parent 7b4df7b

3 files changed: +25 -34 lines

pytorch_lightning/core/optimizer.py

Lines changed: 12 additions & 21 deletions

@@ -93,14 +93,6 @@ def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):
             optimizer = trainer.lightning_optimizers[opt_idx]
         return optimizer
 
-    def _toggle_model(self):
-        model_ref = self._trainer.lightning_module
-        model_ref.toggle_optimizer(self, self._optimizer_idx)
-
-    def _untoggle_model(self):
-        model_ref = self._trainer.lightning_module
-        model_ref.untoggle_optimizer(self)
-
     @contextmanager
     def toggle_model(self, sync_grad: bool = True):
         """This function is just a helper for advanced users.
@@ -116,16 +108,12 @@ def toggle_model(self, sync_grad: bool = True):
         # local import here to avoid circular import
         from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
 
+        lightning_module = self._trainer.lightning_module
+
         with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)):
-            self._toggle_model()
+            lightning_module.toggle_optimizer(self, self._optimizer_idx)
             yield
-            self._untoggle_model()
-
-    def __optimizer_step(self, closure: Callable, profiler_name: str = None, **kwargs):
-        trainer = self._trainer
-
-        with trainer.profiler.profile(profiler_name):
-            trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
+            lightning_module.untoggle_optimizer(self._optimizer_idx)
 
     def step(self, closure: Optional[Callable] = None, **kwargs):
         """Call this directly from your training_step when doing optimizations manually. By using this we can
@@ -193,14 +181,17 @@ def closure_dis():
             opt_dis.step(closure=closure_dis)
         """
         if closure is None:
-            profiler_name = f"closure_{self._optimizer_idx}"
             closure = do_nothing_closure
+            profiler_action = "optimizer_step_without_closure"
+        elif not callable(closure):
+            raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
         else:
-            if not callable(closure):
-                raise MisconfigurationException("When closure is provided, it should be a function")
-            profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
+            profiler_action = "optimizer_step_with_closure"
+        profiler_action += f"_{self._optimizer_idx}"
 
-        self.__optimizer_step(closure=closure, profiler_name=profiler_name, **kwargs)
+        trainer = self._trainer
+        with trainer.profiler.profile(profiler_action):
+            trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
         self._total_optimizer_step_calls += 1
 
     def __repr__(self):
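
For orientation, here is a minimal sketch (not part of the commit) of the manual-optimization path these changes rewire. The module, batch shape, and learning rate are hypothetical; toggle_model and step(closure=...) are the two entry points edited above.

import torch
from pytorch_lightning import LightningModule


class ManualOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # LightningOptimizer wrapper around SGD

        def closure():
            # zero, forward, backward; with this commit the surrounding
            # profiler event is named "optimizer_step_with_closure_0"
            opt.zero_grad()
            loss = self.layer(batch).sum()
            self.manual_backward(loss)
            return loss

        # toggle_model now calls toggle_optimizer/untoggle_optimizer on the
        # LightningModule directly instead of the removed private helpers
        with opt.toggle_model():
            opt.step(closure=closure)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)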

pytorch_lightning/profiler/pytorch.py

Lines changed: 12 additions & 12 deletions

@@ -117,11 +117,11 @@ def pre_step(self, current_action: str) -> None:
 
     def reset(self):
         # handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise.
-        self._num_optimizer_step_and_closure = 0
+        self._num_optimizer_step_with_closure = 0
         self._num_validation_step = 0
         self._num_test_step = 0
         self._num_predict_step = 0
-        self._optimizer_step_and_closure_reached_end = False
+        self._optimizer_step_with_closure_reached_end = False
         self._validation_step_reached_end = False
         self._test_step_reached_end = False
         self._predict_step_reached_end = False
@@ -132,13 +132,13 @@ def reset(self):
     @property
     def is_training(self) -> bool:
         return self._current_action is not None and (
-            self._current_action.startswith("optimizer_step_and_closure_") or self._current_action == "training_step"
+            self._current_action.startswith("optimizer_step_with_closure_") or self._current_action == "training_step"
         )
 
     @property
     def num_step(self) -> int:
         if self.is_training:
-            return self._num_optimizer_step_and_closure
+            return self._num_optimizer_step_with_closure
         if self._current_action == "validation_step":
             return self._num_validation_step
         if self._current_action == "test_step":
@@ -149,10 +149,10 @@ def num_step(self) -> int:
 
     def _step(self) -> None:
         if self.is_training:
-            self._num_optimizer_step_and_closure += 1
+            self._num_optimizer_step_with_closure += 1
         elif self._current_action == "validation_step":
            if self._start_action_name == "on_fit_start":
-                if self._num_optimizer_step_and_closure > 0:
+                if self._num_optimizer_step_with_closure > 0:
                    self._num_validation_step += 1
            else:
                self._num_validation_step += 1
@@ -164,7 +164,7 @@ def _step(self) -> None:
     @property
     def has_finished(self) -> bool:
         if self.is_training:
-            return self._optimizer_step_and_closure_reached_end
+            return self._optimizer_step_with_closure_reached_end
         if self._current_action == "validation_step":
             return self._validation_step_reached_end
         if self._current_action == "test_step":
@@ -182,7 +182,7 @@ def __call__(self, num_step: int) -> "ProfilerAction":
         action = self._schedule(max(self.num_step, 0))
         if action == ProfilerAction.RECORD_AND_SAVE:
             if self.is_training:
-                self._optimizer_step_and_closure_reached_end = True
+                self._optimizer_step_with_closure_reached_end = True
             elif self._current_action == "validation_step":
                 self._validation_step_reached_end = True
             elif self._current_action == "test_step":
@@ -202,9 +202,9 @@ class PyTorchProfiler(BaseProfiler):
         "test_step",
         "predict_step",
     }
-    RECORD_FUNCTION_PREFIX = "optimizer_step_and_closure_"
+    RECORD_FUNCTION_PREFIX = "optimizer_step_with_closure_"
     STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"}
-    STEP_FUNCTION_PREFIX = "optimizer_step_and_closure_"
+    STEP_FUNCTION_PREFIX = "optimizer_step_with_closure_"
     AVAILABLE_SORT_KEYS = {
         "cpu_time",
         "cuda_time",
@@ -383,8 +383,8 @@ def start(self, action_name: str) -> None:
             self._register.__enter__()
 
         if self._lightning_module is not None:
-            # when the model is used in automatic optimization,
-            # we use `optimizer_step_and_closure` to step the model.
+            # when the model is used in automatic optimization, we use `optimizer_step_with_closure` to step the model.
+            # this profiler event is generated in the `LightningOptimizer.step` method
            if self._lightning_module.automatic_optimization and "training_step" in self.STEP_FUNCTIONS:
                self.STEP_FUNCTIONS.remove("training_step")
 
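The rename has to be applied consistently because ScheduleWrapper matches action names by prefix to drive a user-supplied torch.profiler schedule. A rough usage sketch (arguments are illustrative, not from this commit):

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PyTorchProfiler

# the schedule is forwarded to torch.profiler.profile; ScheduleWrapper steps
# it on "optimizer_step_with_closure_*" actions during training
profiler = PyTorchProfiler(
    dirpath="profiler_logs",
    filename="perf",
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
)
trainer = Trainer(profiler=profiler, max_epochs=1)
# trainer.fit(model)  # any LightningModule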

tests/profiler/test_profiler.py

Lines changed: 1 addition & 1 deletion

@@ -285,7 +285,7 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
     files = [file for file in files if file.endswith(".json")]
     assert len(files) == 2, files
     local_rank = trainer.local_rank
-    assert any(f"{local_rank}-optimizer_step_and_closure_" in f for f in files)
+    assert any(f"{local_rank}-optimizer_step_with_closure_" in f for f in files)
     assert any(f"{local_rank}-validation_step" in f for f in files)
 
 
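The updated assertion mirrors the per-rank trace filenames the profiler exports. A standalone sketch of the same check (the directory and rank are hypothetical; the test reads trainer.local_rank):

import os

files = [f for f in os.listdir("profiler_logs") if f.endswith(".json")]
local_rank = 0  # illustrative rank
assert any(f"{local_rank}-optimizer_step_with_closure_" in f for f in files)
assert any(f"{local_rank}-validation_step" in f for f in files)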
