Skip to content

Commit 3c4d06b

Browse files
authored
Update the TQDM progress bar on_train_epoch_end (#11069)
1 parent ffb1a75 commit 3c4d06b

File tree

4 files changed

+94
-22
lines changed

4 files changed

+94
-22
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
277277
- Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991))
278278

279279

280-
-
280+
- The TQDM progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069))
281281

282282

283-
-
283+
- Fixed bug where the TQDM updated the training progress bar during `trainer.validate` ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069))
284284

285285

286286
## [1.5.5] - 2021-12-07

pytorch_lightning/callbacks/progress/tqdm_progress.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
else:
2626
from tqdm import tqdm as _tqdm
2727

28+
import pytorch_lightning as pl
2829
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
2930
from pytorch_lightning.utilities.distributed import rank_zero_debug
3031

@@ -207,12 +208,10 @@ def init_test_tqdm(self) -> Tqdm:
207208
return bar
208209

209210
def on_sanity_check_start(self, trainer, pl_module):
210-
super().on_sanity_check_start(trainer, pl_module)
211211
self.val_progress_bar = self.init_sanity_tqdm()
212212
self.main_progress_bar = Tqdm(disable=True) # dummy progress bar
213213

214214
def on_sanity_check_end(self, trainer, pl_module):
215-
super().on_sanity_check_end(trainer, pl_module)
216215
self.main_progress_bar.close()
217216
self.val_progress_bar.close()
218217

@@ -234,49 +233,59 @@ def on_train_epoch_start(self, trainer, pl_module):
234233

235234
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
236235
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
237-
total_batches = self.total_train_batches + self.total_val_batches
238-
total_batches = convert_inf(total_batches)
239-
if self._should_update(self.train_batch_idx, total_batches):
236+
if self._should_update(self.train_batch_idx):
240237
self._update_bar(self.main_progress_bar)
241238
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
242239

240+
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
241+
if self.is_enabled:
242+
self._update_bar(self.main_progress_bar)
243+
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
244+
245+
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
246+
self.main_progress_bar.close()
247+
243248
def on_validation_start(self, trainer, pl_module):
244249
super().on_validation_start(trainer, pl_module)
245250
if trainer.sanity_checking:
246251
reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx)
247252
else:
248-
self._update_bar(self.main_progress_bar) # fill up remaining
253+
if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
254+
self._update_bar(self.main_progress_bar) # fill up remaining
249255
self.val_progress_bar = self.init_validation_tqdm()
250256
reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx)
251257

252258
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
253259
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
254-
if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)):
260+
if self._should_update(self.val_batch_idx):
261+
self._update_bar(self.val_progress_bar)
262+
if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
263+
self._update_bar(self.main_progress_bar)
264+
265+
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
266+
if self.is_enabled:
255267
self._update_bar(self.val_progress_bar)
256-
self._update_bar(self.main_progress_bar)
257268

258269
def on_validation_end(self, trainer, pl_module):
259-
super().on_validation_end(trainer, pl_module)
260-
if self.main_progress_bar is not None:
270+
if self.main_progress_bar is not None and trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
261271
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
262272
self.val_progress_bar.close()
263273

264-
def on_train_end(self, trainer, pl_module):
265-
super().on_train_end(trainer, pl_module)
266-
self.main_progress_bar.close()
267-
268274
def on_test_start(self, trainer, pl_module):
269275
super().on_test_start(trainer, pl_module)
270276
self.test_progress_bar = self.init_test_tqdm()
271277
self.test_progress_bar.total = convert_inf(self.total_test_batches)
272278

273279
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
274280
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
275-
if self._should_update(self.test_batch_idx, self.total_test_batches):
281+
if self._should_update(self.test_batch_idx):
282+
self._update_bar(self.test_progress_bar)
283+
284+
def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
285+
if self.is_enabled:
276286
self._update_bar(self.test_progress_bar)
277287

278288
def on_test_end(self, trainer, pl_module):
279-
super().on_test_end(trainer, pl_module)
280289
self.test_progress_bar.close()
281290

282291
def on_predict_epoch_start(self, trainer, pl_module):
@@ -286,7 +295,7 @@ def on_predict_epoch_start(self, trainer, pl_module):
286295

287296
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
288297
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
289-
if self._should_update(self.predict_batch_idx, self.total_predict_batches):
298+
if self._should_update(self.predict_batch_idx):
290299
self._update_bar(self.predict_progress_bar)
291300

292301
def on_predict_end(self, trainer, pl_module):
@@ -310,8 +319,8 @@ def print(
310319
s = sep.join(map(str, args))
311320
active_progress_bar.write(s, end=end, file=file, nolock=nolock)
312321

313-
def _should_update(self, current, total) -> bool:
314-
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
322+
def _should_update(self, idx: int) -> bool:
323+
return self.is_enabled and (idx % self.refresh_rate == 0)
315324

316325
def _update_bar(self, bar: Optional[Tqdm]) -> None:
317326
"""Updates the bar by the refresh rate without overshooting."""

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1710,7 +1710,7 @@ def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
17101710
r"""
17111711
.. deprecated:: v1.5
17121712
This method was deprecated in v1.5 in favor of
1713-
`pytorch_lightning.callbacks.progress.base.get_standard_metrics` and will be removed in v1.7.
1713+
`pytorch_lightning.callbacks.progress.base.get_metrics` and will be removed in v1.7.
17141714
17151715
Implement this to override the default items displayed in the progress bar.
17161716
By default it includes the average loss value, split index of BPTT (if used)

tests/callbacks/test_tqdm_progress_bar.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import os
1515
import pickle
1616
import sys
17+
from collections import defaultdict
1718
from typing import Union
1819
from unittest import mock
1920
from unittest.mock import ANY, call, Mock
@@ -595,3 +596,65 @@ def test_tqdm_progress_bar_main_bar_resume():
595596
# restarting mid validation epoch is not currently supported
596597
assert bar.val_progress_bar.n == 0
597598
assert bar.val_progress_bar.total == 3
599+
600+
601+
def test_tqdm_progress_bar_correct_value_epoch_end(tmpdir):
602+
class MockedProgressBar(TQDMProgressBar):
603+
calls = defaultdict(list)
604+
605+
def get_metrics(self, trainer, pl_module):
606+
items = super().get_metrics(trainer, pl_module)
607+
del items["v_num"]
608+
del items["loss"]
609+
# this is equivalent to mocking `set_postfix` as this method gets called every time
610+
self.calls[trainer.state.fn].append(
611+
(trainer.state.stage, trainer.current_epoch, trainer.global_step, items)
612+
)
613+
return items
614+
615+
class MyModel(BoringModel):
616+
def training_step(self, batch, batch_idx):
617+
self.log("a", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
618+
return super().training_step(batch, batch_idx)
619+
620+
def validation_step(self, batch, batch_idx):
621+
self.log("b", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
622+
return super().validation_step(batch, batch_idx)
623+
624+
def test_step(self, batch, batch_idx):
625+
self.log("c", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
626+
return super().test_step(batch, batch_idx)
627+
628+
model = MyModel()
629+
pbar = MockedProgressBar()
630+
trainer = Trainer(
631+
default_root_dir=tmpdir,
632+
limit_train_batches=2,
633+
limit_val_batches=2,
634+
limit_test_batches=2,
635+
max_epochs=2,
636+
enable_model_summary=False,
637+
enable_checkpointing=False,
638+
log_every_n_steps=1,
639+
callbacks=pbar,
640+
)
641+
642+
trainer.fit(model)
643+
assert pbar.calls["fit"] == [
644+
("sanity_check", 0, 0, {"b": 0}),
645+
("train", 0, 0, {}),
646+
("train", 0, 1, {}),
647+
("validate", 0, 1, {"b": 1}), # validation end
648+
# epoch end over, `on_epoch=True` metrics are computed
649+
("train", 0, 2, {"a": 1, "b": 1}), # training epoch end
650+
("train", 1, 2, {"a": 1, "b": 1}),
651+
("train", 1, 3, {"a": 1, "b": 1}),
652+
("validate", 1, 3, {"a": 1, "b": 3}), # validation end
653+
("train", 1, 4, {"a": 3, "b": 3}), # training epoch end
654+
]
655+
656+
trainer.validate(model, verbose=False)
657+
assert pbar.calls["validate"] == []
658+
659+
trainer.test(model, verbose=False)
660+
assert pbar.calls["test"] == []

0 commit comments

Comments
 (0)