
Commit 4740be2

Fix self.log(on_epoch=True) on_batch_start

1 parent: ab20792

File tree (2 files changed: +35 −30 lines)

pytorch_lightning/loops/batch/training_batch_loop.py
tests/trainer/logging_/test_train_loop_logging.py

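For context on the commit message: the change is about being able to call self.log(..., on_epoch=True) from the on_batch_start hook. A minimal, hypothetical callback exercising that pattern (the class and metric names below are illustrative and do not appear in this commit):

from pytorch_lightning import Callback


class BatchStartLogger(Callback):
    """Hypothetical callback: logs an epoch-aggregated value from on_batch_start."""

    def on_batch_start(self, trainer, pl_module):
        # This is the call the commit fixes: epoch-level aggregation of a value
        # logged from on_batch_start, before the training step runs.
        pl_module.log("batch_start_value", 1.0, on_step=False, on_epoch=True)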

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 23 additions & 30 deletions

@@ -44,6 +44,7 @@ def __init__(self) -> None:
         self._outputs: _OUTPUTS_TYPE = []
         self._warning_cache: WarningCache = WarningCache()
         self._remaining_splits: Optional[List[Any]] = None
+        self._exit_signal: int = 0

     @property
     def done(self) -> bool:
@@ -58,35 +59,6 @@ def connect(
         if manual_loop is not None:
             self.manual_loop = manual_loop

-    def run(self, batch: Any, batch_idx: int) -> AttributeDict:
-        """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.
-
-        Args:
-            batch: the current batch to run the train step on
-            batch_idx: the index of the current batch
-        """
-        if batch is None:
-            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
-            return AttributeDict(signal=0, outputs=[])
-
-        # hook
-        self.trainer.logger_connector.on_batch_start()
-        response = self.trainer.call_hook("on_batch_start")
-        if response == -1:
-            return AttributeDict(signal=-1)
-
-        # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
-        if response == -1:
-            return AttributeDict(signal=-1)
-
-        self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()
-
-        super().run(batch, batch_idx)
-
-        output, self._outputs = AttributeDict(signal=0, outputs=self._outputs), None  # free memory
-        return output
-
     def reset(self) -> None:
         """Resets the loop state."""
         self._outputs = []
@@ -108,13 +80,31 @@ def advance(self, batch, batch_idx):
             batch: the current batch to run the training on (this is not the split!)
             batch_idx: the index of the current batch
         """
-        void(batch)
+        if batch is None:
+            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
+            raise StopIteration
+
         split_idx, split_batch = self._remaining_splits.pop(0)
         self.split_idx = split_idx

         # let logger connector extract current batch size
         self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)

+        # hook
+        self.trainer.logger_connector.on_batch_start()
+        response = self.trainer.call_hook("on_batch_start")
+        if response == -1:
+            self._exit_signal = -1
+            raise StopIteration
+
+        # hook
+        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
+        if response == -1:
+            self._exit_signal = -1
+            raise StopIteration
+
+        self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()
+
         # choose which loop will run the optimization
         if self.trainer.lightning_module.automatic_optimization:
             optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
@@ -131,6 +121,9 @@ def on_run_end(self) -> None:
         self.optimizer_loop._hiddens = None
         # this is not necessary as the manual loop runs for only 1 iteration, but just in case
         self.manual_loop._hiddens = None
+        output, self._outputs = AttributeDict(signal=self._exit_signal, outputs=self._outputs), None  # free memory
+        self._exit_signal = 0
+        return output

     def teardown(self) -> None:
         # release memory
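To make the control-flow change above easier to follow: the None-batch check and the on_batch_start/on_train_batch_start hook calls move from run() into advance(), early exits now raise StopIteration with the reason stashed in _exit_signal, and on_run_end() packages the final output. A toy sketch of that pattern (not Lightning's actual Loop base class; any name not visible in the diff above is illustrative):

class ToyBatchLoop:
    """Toy stand-in (not Lightning's Loop API) for the pattern used in this diff."""

    def __init__(self):
        self._outputs = []
        self._exit_signal = 0
        self._remaining_splits = []

    def run(self, batch, batch_idx):
        # Generic driver: reset state, advance until StopIteration, then finalize.
        self.reset(batch)
        while True:
            try:
                self.advance(batch, batch_idx)
            except StopIteration:
                break
        return self.on_run_end()

    def reset(self, batch):
        self._outputs = []
        self._remaining_splits = [batch]  # a single split in this toy

    def advance(self, batch, batch_idx):
        if batch is None:
            raise StopIteration  # nothing to do; exit signal stays 0
        if not self._remaining_splits:
            raise StopIteration  # all splits processed
        split = self._remaining_splits.pop(0)
        if self.call_hook("on_batch_start") == -1:
            self._exit_signal = -1  # remember why we stopped
            raise StopIteration
        self._outputs.append({"split": split, "batch_idx": batch_idx})

    def on_run_end(self):
        # Package the collected outputs together with the stored exit signal,
        # then reset the signal for the next batch (mirrors the diff above).
        output, self._outputs = {"signal": self._exit_signal, "outputs": self._outputs}, None
        self._exit_signal = 0
        return output

    def call_hook(self, name):
        return None  # stand-in for trainer.call_hook


print(ToyBatchLoop().run(batch=[1.0, 2.0], batch_idx=0))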

tests/trainer/logging_/test_train_loop_logging.py

Lines changed: 12 additions & 0 deletions

@@ -276,11 +276,21 @@ def on_train_epoch_start(self, _, pl_module):
             pl_module, "on_train_epoch_start", on_steps=self.choices, on_epochs=[True], prob_bars=self.choices
         )

+    def on_batch_start(self, _, pl_module, *__):
+        self.make_logging(
+            pl_module, "on_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
+        )
+
     def on_batch_end(self, _, pl_module):
         self.make_logging(
             pl_module, "on_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
         )

+    def on_train_batch_start(self, _, pl_module, *__):
+        self.make_logging(
+            pl_module, "on_train_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
+        )
+
     def on_train_batch_end(self, _, pl_module, *__):
         self.make_logging(
             pl_module, "on_train_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
@@ -323,7 +333,9 @@ def training_step(self, batch, batch_idx):
         "on_train_start": 1,
         "on_epoch_start": 1,
         "on_train_epoch_start": 1,
+        "on_train_batch_start": 2,
         "on_train_batch_end": 2,
+        "on_batch_start": 2,
         "on_batch_end": 2,
         "on_train_epoch_end": 1,
         "on_epoch_end": 1,
