
Commit 4617615

awaelchli and carmocca committed

Update the TQDM progress bar on_train_epoch_end (#11069)

Co-authored-by: Carlos Mocholi <[email protected]>

1 parent: b75434e · commit 4617615

File tree: 4 files changed (+94 −20 lines)

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991))
 - Fixed running sanity check with `RichProgressBar` ([#10913](https://github.com/PyTorchLightning/pytorch-lightning/pull/10913))
 - Fixed support for `CombinedLoader` while checking for warning raised with eval dataloaders ([#10994](https://github.com/PyTorchLightning/pytorch-lightning/pull/10994))
+- The TQDM progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069))
+- Fixed a bug where the TQDM progress bar updated the main training bar during `trainer.validate` ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069))


 ## [1.5.5] - 2021-12-07
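
For context, a minimal sketch of the user-visible effect of the first entry (assuming pytorch-lightning 1.5.x; `TinyModel` and its synthetic data are illustrative, not part of this commit): a value logged with `prog_bar=True, on_epoch=True` in `training_step` now shows up in the TQDM bar's postfix when the train epoch ends.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).mean()
        # With this fix, the epoch-aggregated value of "a" appears in the
        # TQDM postfix at train epoch end (via the new on_train_epoch_end hook).
        self.log("a", float(self.global_step), prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
trainer = pl.Trainer(max_epochs=2, limit_train_batches=2, enable_checkpointing=False)
trainer.fit(TinyModel(), data)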

pytorch_lightning/callbacks/progress/tqdm_progress.py

Lines changed: 28 additions & 19 deletions

@@ -25,6 +25,7 @@
 else:
     from tqdm import tqdm as _tqdm

+import pytorch_lightning as pl
 from pytorch_lightning.callbacks.progress.base import ProgressBarBase

 _PAD_SIZE = 5
@@ -206,12 +207,10 @@ def init_test_tqdm(self) -> Tqdm:
         return bar

     def on_sanity_check_start(self, trainer, pl_module):
-        super().on_sanity_check_start(trainer, pl_module)
         self.val_progress_bar = self.init_sanity_tqdm()
         self.main_progress_bar = Tqdm(disable=True)  # dummy progress bar

     def on_sanity_check_end(self, trainer, pl_module):
-        super().on_sanity_check_end(trainer, pl_module)
         self.main_progress_bar.close()
         self.val_progress_bar.close()

@@ -233,49 +232,59 @@ def on_train_epoch_start(self, trainer, pl_module):

     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
-        total_batches = self.total_train_batches + self.total_val_batches
-        total_batches = convert_inf(total_batches)
-        if self._should_update(self.train_batch_idx, total_batches):
+        if self._should_update(self.train_batch_idx):
             self._update_bar(self.main_progress_bar)
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self.is_enabled:
+            self._update_bar(self.main_progress_bar)
+            self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
+
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self.main_progress_bar.close()
+
     def on_validation_start(self, trainer, pl_module):
         super().on_validation_start(trainer, pl_module)
         if trainer.sanity_checking:
             reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx)
         else:
-            self._update_bar(self.main_progress_bar)  # fill up remaining
+            if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
+                self._update_bar(self.main_progress_bar)  # fill up remaining
             self.val_progress_bar = self.init_validation_tqdm()
             reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx)

     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)):
+        if self._should_update(self.val_batch_idx):
+            self._update_bar(self.val_progress_bar)
+            if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
+                self._update_bar(self.main_progress_bar)
+
+    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self.is_enabled:
             self._update_bar(self.val_progress_bar)
-            self._update_bar(self.main_progress_bar)

     def on_validation_end(self, trainer, pl_module):
-        super().on_validation_end(trainer, pl_module)
-        if self.main_progress_bar is not None:
+        if self.main_progress_bar is not None and trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
         self.val_progress_bar.close()

-    def on_train_end(self, trainer, pl_module):
-        super().on_train_end(trainer, pl_module)
-        self.main_progress_bar.close()
-
     def on_test_start(self, trainer, pl_module):
         super().on_test_start(trainer, pl_module)
         self.test_progress_bar = self.init_test_tqdm()
         self.test_progress_bar.total = convert_inf(self.total_test_batches)

     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self._should_update(self.test_batch_idx, self.total_test_batches):
+        if self._should_update(self.test_batch_idx):
+            self._update_bar(self.test_progress_bar)
+
+    def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self.is_enabled:
             self._update_bar(self.test_progress_bar)

     def on_test_end(self, trainer, pl_module):
-        super().on_test_end(trainer, pl_module)
         self.test_progress_bar.close()

     def on_predict_epoch_start(self, trainer, pl_module):
@@ -285,7 +294,7 @@ def on_predict_epoch_start(self, trainer, pl_module):

     def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self._should_update(self.predict_batch_idx, self.total_predict_batches):
+        if self._should_update(self.predict_batch_idx):
             self._update_bar(self.predict_progress_bar)

     def on_predict_end(self, trainer, pl_module):
@@ -309,8 +318,8 @@ def print(
         s = sep.join(map(str, args))
         active_progress_bar.write(s, end=end, file=file, nolock=nolock)

-    def _should_update(self, current, total) -> bool:
-        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
+    def _should_update(self, idx: int) -> bool:
+        return self.is_enabled and (idx % self.refresh_rate == 0)

     def _update_bar(self, bar: Optional[Tqdm]) -> None:
         """Updates the bar by the refresh rate without overshooting."""
pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion

@@ -1792,7 +1792,7 @@ def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
         r"""
         .. deprecated:: v1.5
             This method was deprecated in v1.5 in favor of
-            `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and will be removed in v1.7.
+            `pytorch_lightning.callbacks.progress.base.get_metrics` and will be removed in v1.7.

         Implement this to override the default items displayed in the progress bar.
         By default it includes the average loss value, split index of BPTT (if used)
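
As a usage note on the corrected deprecation target, a hedged sketch of the replacement pattern (assuming pytorch-lightning 1.5.x; `NoVersionBar` is an illustrative name, not part of this commit): rather than overriding the deprecated `LightningModule.get_progress_bar_dict`, customize the displayed items by overriding `get_metrics` on a progress bar callback.

from pytorch_lightning.callbacks import TQDMProgressBar


class NoVersionBar(TQDMProgressBar):
    def get_metrics(self, trainer, pl_module):
        # Start from the default items (loss, v_num, and any metrics
        # logged with prog_bar=True) ...
        items = super().get_metrics(trainer, pl_module)
        # ... and drop the version number from the displayed postfix.
        items.pop("v_num", None)
        return items

The callback is then passed to the `Trainer` via `callbacks=[NoVersionBar()]`, the same pattern the new test below uses to observe postfix updates.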

tests/callbacks/test_tqdm_progress_bar.py

Lines changed: 63 additions & 0 deletions

@@ -14,6 +14,7 @@
 import os
 import pickle
 import sys
+from collections import defaultdict
 from typing import Optional, Union
 from unittest import mock
 from unittest.mock import ANY, call, Mock
@@ -607,3 +608,65 @@ def test_tqdm_progress_bar_main_bar_resume():
     # restarting mid validation epoch is not currently supported
     assert bar.val_progress_bar.n == 0
     assert bar.val_progress_bar.total == 3
+
+
+def test_tqdm_progress_bar_correct_value_epoch_end(tmpdir):
+    class MockedProgressBar(TQDMProgressBar):
+        calls = defaultdict(list)
+
+        def get_metrics(self, trainer, pl_module):
+            items = super().get_metrics(trainer, pl_module)
+            del items["v_num"]
+            del items["loss"]
+            # this is equivalent to mocking `set_postfix`, as this method gets called every time
+            self.calls[trainer.state.fn].append(
+                (trainer.state.stage, trainer.current_epoch, trainer.global_step, items)
+            )
+            return items
+
+    class MyModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            self.log("a", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            self.log("b", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
+            return super().validation_step(batch, batch_idx)
+
+        def test_step(self, batch, batch_idx):
+            self.log("c", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
+            return super().test_step(batch, batch_idx)
+
+    model = MyModel()
+    pbar = MockedProgressBar()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        max_epochs=2,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        log_every_n_steps=1,
+        callbacks=pbar,
+    )
+
+    trainer.fit(model)
+    assert pbar.calls["fit"] == [
+        ("sanity_check", 0, 0, {"b": 0}),
+        ("train", 0, 0, {}),
+        ("train", 0, 1, {}),
+        ("validate", 0, 1, {"b": 1}),  # validation end
+        # the epoch is over, so the `on_epoch=True` metrics are computed
+        ("train", 0, 2, {"a": 1, "b": 1}),  # training epoch end
+        ("train", 1, 2, {"a": 1, "b": 1}),
+        ("train", 1, 3, {"a": 1, "b": 1}),
+        ("validate", 1, 3, {"a": 1, "b": 3}),  # validation end
+        ("train", 1, 4, {"a": 3, "b": 3}),  # training epoch end
+    ]
+
+    trainer.validate(model, verbose=False)
+    assert pbar.calls["validate"] == []
+
+    trainer.test(model, verbose=False)
+    assert pbar.calls["test"] == []
