
Commit c69a79c

Fix self.log(on_epoch=True) on_batch_start (#9780)
1 parent 8c76cf5 commit c69a79c

8 files changed: +72 additions, -86 deletions


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -527,11 +527,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `broadcast` in `DDPPlugin` and `DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))


+- Fixed `self.log(on_epoch=True)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9780](https://github.com/PyTorchLightning/pytorch-lightning/pull/9780))
+
+
 - Fixed restoring training state during `trainer.fit` only ([#9413](https://github.com/PyTorchLightning/pytorch-lightning/pull/9413))


 - Fixed DeepSpeed and Lightning both calling the scheduler ([#9788](https://github.com/PyTorchLightning/pytorch-lightning/pull/9788))

+
 - Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))

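For context, the behavior this new entry restores: values logged with `on_epoch=True` from the batch-start hooks are now handled correctly (accumulated and reduced at the end of the epoch). A minimal sketch of the user-facing pattern; the module, metric name, and hyperparameters below are illustrative and not part of this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def on_train_batch_start(self, batch, batch_idx):
        # with #9780, on_epoch=True from this hook is reduced (mean by default)
        # across the epoch and ends up in trainer.callback_metrics
        self.log("batch_idx_avg", float(batch_idx), on_step=False, on_epoch=True)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)


if __name__ == "__main__":
    trainer = Trainer(max_epochs=1, logger=False)
    trainer.fit(LitModel())
    print(trainer.callback_metrics)  # includes the epoch-reduced value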

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 5 additions & 44 deletions
@@ -23,9 +23,6 @@
 from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
 from pytorch_lightning.loops.utilities import _get_active_optimizers
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities import AttributeDict
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.warnings import WarningCache

 _OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]

@@ -43,7 +40,6 @@ def __init__(self) -> None:
         self.manual_loop = ManualOptimization()

         self._outputs: _OUTPUTS_TYPE = []
-        self._warning_cache: WarningCache = WarningCache()
         self._remaining_splits: Optional[List[Any]] = None

     @property
@@ -59,42 +55,6 @@ def connect(
         if manual_loop is not None:
             self.manual_loop = manual_loop

-    def run(self, batch: Any, batch_idx: int) -> AttributeDict:
-        """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.
-
-        Args:
-            batch: the current batch to run the train step on
-            batch_idx: the index of the current batch
-        """
-        if batch is None:
-            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
-            return AttributeDict(signal=0, outputs=[])
-
-        # hook
-        self.trainer.logger_connector.on_batch_start()
-        response = self.trainer.call_hook("on_batch_start")
-        if response == -1:
-            return AttributeDict(signal=-1)
-
-        # hook
-        # TODO: Update this in v1.7 (deprecation: #9816)
-        model_fx = self.trainer.lightning_module.on_train_batch_start
-        extra_kwargs = (
-            {"dataloader_idx": 0}
-            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
-            else {}
-        )
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-        if response == -1:
-            return AttributeDict(signal=-1)
-
-        self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()
-
-        super().run(batch, batch_idx)
-
-        output, self._outputs = AttributeDict(signal=0, outputs=self._outputs), None  # free memory
-        return output
-
     def reset(self) -> None:
         """Resets the loop state."""
         self._outputs = []
@@ -117,11 +77,10 @@ def advance(self, batch, batch_idx):
             batch_idx: the index of the current batch
         """
         void(batch)
-        split_idx, split_batch = self._remaining_splits.pop(0)
-        self.split_idx = split_idx
+        self.split_idx, split_batch = self._remaining_splits.pop(0)

         # let logger connector extract current batch size
-        self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)
+        self.trainer.logger_connector.on_train_split_start(self.split_idx, split_batch)

         # choose which loop will run the optimization
         if self.trainer.lightning_module.automatic_optimization:
@@ -135,10 +94,12 @@ def advance(self, batch, batch_idx):
             # then `advance` doesn't finish and an empty dict is returned
             self._outputs.append(outputs)

-    def on_run_end(self) -> None:
+    def on_run_end(self) -> _OUTPUTS_TYPE:
         self.optimizer_loop._hiddens = None
         # this is not necessary as the manual loop runs for only 1 iteration, but just in case
         self.manual_loop._hiddens = None
+        output, self._outputs = self._outputs, None  # free memory
+        return output

     def teardown(self) -> None:
         # release memory

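With the batch-start hooks lifted out of this loop (see training_epoch_loop.py below), `TrainingBatchLoop.run` no longer wraps its result in an `AttributeDict` with a `signal` field; it returns the list of per-split/per-optimizer outputs directly. A short sketch of the new calling pattern, assuming `trainer` and `model` come from a BoringModel-style setup and that `trainer.fit(model)` has already run so the loops are attached, exactly as in the updated tests further down in this commit:

from pytorch_lightning.trainer.states import RunningStage

# simulate a single training batch manually (mirrors tests/loops/test_evaluation_loop_flow.py)
trainer.state.stage = RunningStage.TRAINING
batch_idx, batch = 0, next(iter(model.train_dataloader()))
outputs = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

# `outputs` is now a plain list; ending the epoch early is signalled by the
# epoch loop raising StopIteration, not by a `signal` attribute here
assert isinstance(outputs, list)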

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 2 additions & 2 deletions
@@ -233,10 +233,10 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
         Raises:
             AssertionError: If the number of dataloaders is None (has not yet been set).
         """
-        self.trainer.logger_connector.on_batch_start()
+        self.trainer.logger_connector.on_batch_start(batch_idx)

         assert self._num_dataloaders is not None
-        self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders)
+        self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self._num_dataloaders)

         if self.trainer.testing:
             self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 33 additions & 8 deletions
@@ -28,6 +28,7 @@
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
+from pytorch_lightning.utilities.warnings import WarningCache

 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

@@ -57,6 +58,7 @@ def __init__(self, min_steps: int, max_steps: int):

         self._results = ResultCollection(training=True)
         self._outputs: _OUTPUTS_TYPE = []
+        self._warning_cache = WarningCache()
         self._dataloader_iter: Optional[Iterator] = None
         # caches the loaded dataloader state until dataloader objects are available
         self._dataloader_state_dict: Dict[str, Any] = {}
@@ -151,14 +153,37 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

         self.batch_progress.increment_ready()

-        with self.trainer.profiler.profile("run_training_batch"):
-            batch_output = self.batch_loop.run(batch, batch_idx)
+        if batch is None:
+            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
+            batch_output = []
+        else:
+            # hook
+            self.trainer.logger_connector.on_batch_start(batch_idx)
+            response = self.trainer.call_hook("on_batch_start")
+            if response == -1:
+                self.batch_progress.increment_processed()
+                raise StopIteration
+
+            # TODO: Update this in v1.7 (deprecation: #9816)
+            model_fx = self.trainer.lightning_module.on_train_batch_start
+            extra_kwargs = (
+                {"dataloader_idx": 0}
+                if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
+                else {}
+            )

-        self.batch_progress.increment_processed()
+            # hook
+            response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
+            if response == -1:
+                self.batch_progress.increment_processed()
+                raise StopIteration

-        # when returning -1 from train_step, we end epoch early
-        if batch_output.signal == -1:
-            raise StopIteration
+            self.batch_progress.increment_started()
+
+            with self.trainer.profiler.profile("run_training_batch"):
+                batch_output = self.batch_loop.run(batch, batch_idx)
+
+            self.batch_progress.increment_processed()

         # update non-plateau LR schedulers
         # update epoch-interval ones only when we are at the end of training epoch
@@ -167,7 +192,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

         batch_end_outputs = self._prepare_outputs_training_batch_end(
-            batch_output.outputs,
+            batch_output,
             automatic=self.trainer.lightning_module.trainer.lightning_module.automatic_optimization,
             num_optimizers=len(self.trainer.optimizers),
         )
@@ -186,7 +211,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         self.batch_progress.increment_completed()

         if is_overridden("training_epoch_end", self.trainer.lightning_module):
-            self._outputs.append(batch_output.outputs)
+            self._outputs.append(batch_output)

         # -----------------------------------------
         # SAVE METRICS TO LOGGERS AND PROGRESS_BAR

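Moving the hooks here keeps the documented contract that returning -1 from `on_batch_start` or `on_train_batch_start` ends the current training epoch early; the epoch loop now raises `StopIteration` itself instead of inspecting a `signal` field on the batch loop's output. A small sketch of that contract, as a hook override added to a LightningModule like the one sketched after the CHANGELOG hunk above (illustrative, not part of this commit):

    def on_train_batch_start(self, batch, batch_idx):
        # returning -1 skips the remaining batches of the current epoch;
        # with this refactor the epoch loop turns it into StopIteration
        if batch_idx >= 2:
            return -1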

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 4 additions & 6 deletions
@@ -138,15 +138,14 @@ def _increment_eval_log_step(self) -> None:
         elif self.trainer.state.stage is RunningStage.TESTING:
             self._test_log_step += 1

-    def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None:
+    def on_evaluation_batch_start(self, batch: Any, dataloader_idx: int, num_dataloaders: int) -> None:
         model = self.trainer.lightning_module
         # set dataloader_idx only if multiple ones
         model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

         # track batch_size
         assert self.trainer._results is not None
         self.trainer._results.extract_batch_size(batch)
-        self._batch_idx = batch_idx

     def update_eval_step_metrics(self) -> None:
         if self.trainer.sanity_checking:
@@ -213,14 +212,12 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
     Train metric updates
     """

-    def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
+    def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
         assert self.trainer._results is not None
         # when the user requests `dataloader_iter`, we can't track the batch_size
         # and this is left to user responsibility.
         if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
             self.trainer._results.extract_batch_size(split_batch)
-
-        self._batch_idx = batch_idx
         self._split_idx = split_idx

     def update_train_step_metrics(self) -> None:
@@ -267,7 +264,8 @@ def _log_gpus_metrics(self) -> None:
     def on_epoch_start(self) -> None:
         self._epoch_end_reached = False

-    def on_batch_start(self) -> None:
+    def on_batch_start(self, batch_idx: int) -> None:
+        self._batch_idx = batch_idx
         self._epoch_end_reached = False

     def epoch_end_reached(self) -> None:

tests/loops/test_evaluation_loop_flow.py

Lines changed: 2 additions & 6 deletions
@@ -64,10 +64,8 @@ def backward(self, loss, optimizer, optimizer_idx):
     # simulate training manually
     trainer.state.stage = RunningStage.TRAINING
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
-    assert out.signal == 0
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

-    train_step_out = out.outputs
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
     assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -129,10 +127,8 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
-    assert out.signal == 0
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

-    train_step_out = out.outputs
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
     assert isinstance(train_step_out["loss"], torch.Tensor)

tests/loops/test_training_loop_flow_scalar.py

Lines changed: 10 additions & 20 deletions
@@ -147,10 +147,8 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
-    assert out.signal == 0
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

-    train_step_out = out.outputs
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
     assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -221,10 +219,8 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
-    assert out.signal == 0
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

-    train_step_out = out.outputs
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
     assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -311,8 +307,7 @@ def training_step(self, batch, batch_idx):
     for batch_idx, batch in enumerate(model.train_dataloader()):
         out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
         if not batch_idx % 2:
-            assert out.outputs == []
-            assert out.signal == 0
+            assert out == []


 def test_training_step_none_batches(tmpdir):
@@ -321,7 +316,6 @@ def test_training_step_none_batches(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
-
             self.counter = 0

         def collate_none_when_even(self, batch):
@@ -333,12 +327,17 @@ def collate_none_when_even(self, batch):
             return result

         def train_dataloader(self):
-            return DataLoader(RandomDataset(32, 64), collate_fn=self.collate_none_when_even)
+            return DataLoader(RandomDataset(32, 4), collate_fn=self.collate_none_when_even)
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            if batch_idx % 2 == 0:
+                assert outputs == []
+            else:
+                assert outputs

     model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
-        limit_train_batches=4,
         limit_val_batches=1,
         max_epochs=4,
         enable_model_summary=False,
@@ -348,12 +347,3 @@ def train_dataloader(self):

     with pytest.warns(UserWarning, match=r".*train_dataloader yielded None.*"):
         trainer.fit(model)
-
-    trainer.state.stage = RunningStage.TRAINING
-
-    # manually check a few batches
-    for batch_idx, batch in enumerate(model.train_dataloader()):
-        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
-        if not batch_idx % 2:
-            assert out.outputs == []
-            assert out.signal == 0

tests/trainer/logging_/test_train_loop_logging.py

Lines changed: 12 additions & 0 deletions
@@ -276,11 +276,21 @@ def on_train_epoch_start(self, _, pl_module):
                 pl_module, "on_train_epoch_start", on_steps=self.choices, on_epochs=[True], prob_bars=self.choices
             )

+        def on_batch_start(self, _, pl_module, *__):
+            self.make_logging(
+                pl_module, "on_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
+            )
+
         def on_batch_end(self, _, pl_module):
             self.make_logging(
                 pl_module, "on_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
             )

+        def on_train_batch_start(self, _, pl_module, *__):
+            self.make_logging(
+                pl_module, "on_train_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
+            )
+
         def on_train_batch_end(self, _, pl_module, *__):
             self.make_logging(
                 pl_module, "on_train_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
@@ -323,7 +333,9 @@ def training_step(self, batch, batch_idx):
         "on_train_start": 1,
         "on_epoch_start": 1,
         "on_train_epoch_start": 1,
+        "on_train_batch_start": 2,
         "on_train_batch_end": 2,
+        "on_batch_start": 2,
         "on_batch_end": 2,
         "on_train_epoch_end": 1,
         "on_epoch_end": 1,

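The two new callback hooks in this test assert that `on_batch_start` and `on_train_batch_start` are now valid logging hooks, each expected to fire twice in this run (hence the expected counts of 2). A condensed sketch of the user-facing pattern being verified, with an illustrative callback and metric name that are not part of this commit:

from pytorch_lightning import Callback


class BatchStartLogger(Callback):
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        # logging with on_epoch=True from this hook is what #9780 fixes;
        # on_step=True additionally logs the per-step value
        pl_module.log("cb_batch_idx", float(batch_idx), on_step=True, on_epoch=True)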