Deprecate LightningModule.get_progress_bar_dict #8985


Merged: 34 commits, merged on Sep 9, 2021.

Changes shown from 33 commits.

Commits:
2f36657
Move get_progress_bar_dict from lightning module to progress bar call…
daniellepintz Aug 18, 2021
3be4395
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2021
8016914
update changelog
daniellepintz Aug 18, 2021
6e38eeb
Merge branch 'progress_bar' of https://github.com/daniellepintz/pytor…
daniellepintz Aug 18, 2021
8b8531b
update changelog
daniellepintz Aug 18, 2021
5903553
wip
daniellepintz Aug 24, 2021
471f9db
deprecate progress_bar_dict in properties.py
daniellepintz Aug 25, 2021
84750c3
wip
daniellepintz Aug 25, 2021
cd66e83
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Aug 25, 2021
c69e0a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 25, 2021
b85212d
wip
daniellepintz Aug 25, 2021
1925303
Merge branch 'progress_bar' of github.com:daniellepintz/pytorch-light…
daniellepintz Aug 25, 2021
b361575
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 25, 2021
ef6919d
Add get_progress_bar_metrics
daniellepintz Aug 26, 2021
a7d72c7
Merge branch 'progress_bar' of github.com:daniellepintz/pytorch-light…
daniellepintz Aug 26, 2021
ddc0a7a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2021
a8a9c7c
update docstring
daniellepintz Aug 26, 2021
480f11c
Merge branch 'progress_bar' of github.com:daniellepintz/pytorch-light…
daniellepintz Aug 26, 2021
d6bba4f
wip
daniellepintz Aug 26, 2021
5b7180e
update
tchaton Aug 26, 2021
5f4f3c5
Move functions to progressbarbase
daniellepintz Aug 27, 2021
5b0863b
wip
daniellepintz Aug 27, 2021
263cb3f
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Aug 27, 2021
6698526
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
0df7ebd
update __about__
daniellepintz Aug 27, 2021
61c3525
Merge branch 'progress_bar' of github.com:daniellepintz/pytorch-light…
daniellepintz Aug 27, 2021
a49690d
address comments
daniellepintz Aug 28, 2021
cd57525
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Aug 28, 2021
f7cb74a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2021
fa380ff
address comments
daniellepintz Aug 28, 2021
2a0e779
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Sep 7, 2021
3d96481
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2021
9897936
address comments
daniellepintz Sep 7, 2021
52a6045
Update for RichProgressBar
kaushikb11 Sep 9, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -169,6 +169,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))


- Deprecated `LightningModule.get_progress_bar_dict` and `Trainer.progress_bar_dict` in favor of `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and `ProgressBarBase.get_metrics` ([#8985](https://github.com/PyTorchLightning/pytorch-lightning/pull/8985))


- Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` ([#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958))


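For downstream code, the first half of the migration is removing the module-level override, which now triggers a deprecation warning at the start of ``fit()``. A hedged sketch of the deprecated pattern (``MyModel`` is a made-up name); the callback-based replacement appears in the ``logging.rst`` change below:

.. code-block:: python

    from pytorch_lightning import LightningModule


    class MyModel(LightningModule):
        # Deprecated in v1.5 and removed in v1.7: overriding this hook on the
        # module now emits a rank_zero_deprecation warning when fit() starts.
        def get_progress_bar_dict(self):
            items = super().get_progress_bar_dict()
            items.pop("v_num", None)  # hide the experiment version number
            return items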
6 changes: 0 additions & 6 deletions docs/source/common/lightning_module.rst
@@ -1242,12 +1242,6 @@ backward
.. automethod:: pytorch_lightning.core.lightning.LightningModule.backward
:noindex:

get_progress_bar_dict
~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict
:noindex:

on_before_backward
~~~~~~~~~~~~~~~~~~

6 changes: 3 additions & 3 deletions docs/source/extensions/logging.rst
@@ -245,13 +245,13 @@ Modifying the progress bar

The progress bar by default already includes the training loss and version number of the experiment
if you are using a logger. These defaults can be customized by overriding the
- :func:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module.
+ :func:`~pytorch_lightning.callbacks.progress.base.ProgressBarBase.get_metrics` hook in your progress bar callback.

.. code-block:: python

-    def get_progress_bar_dict(self):
+    def get_metrics(self, trainer, model):
         # don't show the version number
-        items = super().get_progress_bar_dict()
+        items = super().get_metrics(trainer, model)
         items.pop("v_num", None)
         return items

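To put the new hook into effect, the override lives on a progress bar callback that is passed to the ``Trainer``. A minimal sketch (``LitProgressBar`` is a hypothetical name; ``ProgressBar`` is the tqdm-based default callback):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ProgressBar


    class LitProgressBar(ProgressBar):
        def get_metrics(self, trainer, model):
            # drop the version number from the postfix, keep everything else
            items = super().get_metrics(trainer, model)
            items.pop("v_num", None)
            return items


    trainer = Trainer(callbacks=[LitProgressBar()])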
71 changes: 71 additions & 0 deletions pytorch_lightning/callbacks/progress/base.py
@@ -11,7 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_warn


class ProgressBarBase(Callback):
@@ -177,3 +181,70 @@ def on_predict_epoch_start(self, trainer, pl_module):

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._predict_batch_idx += 1

def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
r"""
Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
Implement this to override the items displayed in the progress bar.

Here is an example of how to override the defaults:

.. code-block:: python

def get_metrics(self, trainer, model):
# don't show the version number
items = super().get_metrics(trainer, model)
items.pop("v_num", None)
return items

Return:
Dictionary with the items to be displayed in the progress bar.
"""
standard_metrics = pl_module.get_progress_bar_dict()
pbar_metrics = trainer.progress_bar_metrics
duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
if duplicates:
rank_zero_warn(
f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
" If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
UserWarning,
)

return {**standard_metrics, **pbar_metrics}


def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
r"""
Returns several standard metrics displayed in the progress bar, including the average loss value,
split index of BPTT (if used) and the version of the experiment when using a logger.

.. code-block::

Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]

Return:
Dictionary with the standard metrics to be displayed in the progress bar.
"""
# call .item() only once but store elements without graphs
running_train_loss = trainer.fit_loop.running_loss.mean()
avg_training_loss = None
if running_train_loss is not None:
avg_training_loss = running_train_loss.cpu().item()
elif pl_module.automatic_optimization:
avg_training_loss = float("NaN")

items_dict = {}
if avg_training_loss is not None:
items_dict["loss"] = f"{avg_training_loss:.3g}"

if pl_module.truncated_bptt_steps > 0:
items_dict["split_idx"] = trainer.fit_loop.split_idx

if trainer.logger is not None and trainer.logger.version is not None:
version = trainer.logger.version
# show last 4 places of long version strings
version = version[-4:] if isinstance(version, str) else version
items_dict["v_num"] = version

return items_dict
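The merge order in ``get_metrics`` is what makes the duplicate-name warning above necessary: logged metrics win over the standard ones. A toy illustration of the semantics (plain dicts, not library code):

.. code-block:: python

    # Values as they might come out of the two sources:
    standard_metrics = {"loss": "0.123", "v_num": 10}  # from get_standard_metrics()
    pbar_metrics = {"loss": 0.456, "acc": 0.91}  # from self.log(..., prog_bar=True)

    # In {**standard_metrics, **pbar_metrics}, keys from pbar_metrics overwrite
    # duplicates, so the logged "loss" replaces the running-loss display.
    merged = {**standard_metrics, **pbar_metrics}
    assert merged == {"loss": 0.456, "v_num": 10, "acc": 0.91}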
8 changes: 7 additions & 1 deletion pytorch_lightning/callbacks/progress/rich_progress.py
@@ -64,7 +64,13 @@ def render(self, task) -> Text:
if self._trainer.training and task.id != self._current_task_id:
return self._tasks[task.id]
_text = ""
-        for k, v in self._trainer.progress_bar_dict.items():
+        # TODO(@daniellepintz): make this code cleaner
+        progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None)
+        if progress_bar_callback:
+            metrics = progress_bar_callback.get_metrics(self._trainer, self._trainer.lightning_module)
+        else:
+            metrics = self._trainer.progress_bar_metrics
+        for k, v in metrics.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
text = Text.from_markup(_text, style=None, justify="left")
return text
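For reference, the Rich-based bar is attached like any other callback; a minimal sketch, assuming ``RichProgressBar`` is exported from ``pytorch_lightning.callbacks`` in this release cycle and that the ``rich`` package is installed:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import RichProgressBar

    # The MetricsTextColumn above renders whatever get_metrics() returns,
    # falling back to trainer.progress_bar_metrics when no progress bar
    # callback is attached.
    trainer = Trainer(callbacks=[RichProgressBar()])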
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -237,7 +237,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
total_batches = convert_inf(total_batches)
if self._should_update(self.train_batch_idx, total_batches):
self._update_bar(self.main_progress_bar)
-            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
+            self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
@@ -257,7 +257,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
if self.main_progress_bar is not None:
-            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
+            self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
self.val_progress_bar.close()

def on_train_end(self, trainer, pl_module):
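``set_postfix`` just renders a mapping at the end of the bar line, so the change above only swaps where that mapping comes from. A standalone sketch with plain ``tqdm``, independent of Lightning:

.. code-block:: python

    from tqdm import tqdm

    bar = tqdm(range(100))
    for i in bar:
        # mirrors what the callback now passes: short, pre-formatted values
        bar.set_postfix({"loss": f"{1.0 / (i + 1):.3g}", "v_num": 10})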
28 changes: 6 additions & 22 deletions pytorch_lightning/core/lightning.py
@@ -31,6 +31,7 @@
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric

from pytorch_lightning.callbacks.progress import base as progress_base
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -1702,6 +1703,10 @@ def unfreeze(self) -> None:

def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
r"""
.. deprecated:: v1.5
This method was deprecated in v1.5 in favor of
`pytorch_lightning.callbacks.progress.base.get_standard_metrics` and will be removed in v1.7.

Implement this to override the default items displayed in the progress bar.
By default it includes the average loss value, split index of BPTT (if used)
and the version of the experiment when using a logger.
@@ -1723,28 +1728,7 @@ def get_progress_bar_dict(self):
Return:
Dictionary with the items to be displayed in the progress bar.
"""
-        # call .item() only once but store elements without graphs
-        running_train_loss = self.trainer.fit_loop.running_loss.mean()
-        avg_training_loss = None
-        if running_train_loss is not None:
-            avg_training_loss = running_train_loss.cpu().item()
-        elif self.automatic_optimization:
-            avg_training_loss = float("NaN")
-
-        tqdm_dict = {}
-        if avg_training_loss is not None:
-            tqdm_dict["loss"] = f"{avg_training_loss:.3g}"
-
-        if self.truncated_bptt_steps > 0:
-            tqdm_dict["split_idx"] = self.trainer.fit_loop.split_idx
-
-        if self.trainer.logger is not None and self.trainer.logger.version is not None:
-            version = self.trainer.logger.version
-            # show last 4 places of long version strings
-            version = version[-4:] if isinstance(version, str) else version
-            tqdm_dict["v_num"] = version
-
-        return tqdm_dict
+        return progress_base.get_standard_metrics(self.trainer, self)

def _verify_is_manual_optimization(self, fn_name):
if self.automatic_optimization:
15 changes: 15 additions & 0 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -43,6 +43,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
elif self.trainer.state.fn == TrainerFn.PREDICTING:
self.__verify_predict_loop_configuration(model)
self.__verify_dp_batch_transfer_support(model)
# TODO(@daniellepintz): Delete _check_progress_bar in v1.7
self._check_progress_bar(model)
# TODO: Delete _check_on_keyboard_interrupt in v1.7
self._check_on_keyboard_interrupt()

@@ -111,6 +113,19 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None
"(rather, they are called on every optimization step)."
)

def _check_progress_bar(self, model: "pl.LightningModule") -> None:
r"""
        Checks if ``get_progress_bar_dict`` is overridden in the given model and emits a deprecation warning.

        Args:
            model: The model to check for an overridden ``get_progress_bar_dict`` method.
"""
if is_overridden("get_progress_bar_dict", model):
rank_zero_deprecation(
"The `LightningModule.get_progress_bar_dict` method was deprecated in v1.5 and will be removed in v1.7."
" Please use the `ProgressBarBase.get_metrics` instead."
)

def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: str) -> None:
loader_name = f"{stage}_dataloader"
step_name = "validation_step" if stage == "val" else "test_step"
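The deprecation path hinges on ``is_overridden`` detecting a user-supplied ``get_progress_bar_dict``. Roughly, the helper compares function objects between the instance's class and the base class; a simplified sketch of the idea (not the library's actual implementation):

.. code-block:: python

    def is_overridden_sketch(method_name: str, instance, parent: type) -> bool:
        # The hook counts as overridden when the class of `instance` supplies
        # a different function object than `parent` does for the same name.
        instance_fn = getattr(type(instance), method_name, None)
        parent_fn = getattr(parent, method_name, None)
        return instance_fn is not None and instance_fn is not parent_fn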
28 changes: 8 additions & 20 deletions pytorch_lightning/trainer/properties.py
@@ -39,13 +39,7 @@
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
-from pytorch_lightning.utilities import (
-    DeviceType,
-    DistributedType,
-    GradClipAlgorithmType,
-    rank_zero_deprecation,
-    rank_zero_warn,
-)
+from pytorch_lightning.utilities import DeviceType, DistributedType, GradClipAlgorithmType, rank_zero_deprecation
from pytorch_lightning.utilities.argparse import (
add_argparse_args,
from_argparse_args,
@@ -306,21 +300,15 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]:
@property
def progress_bar_dict(self) -> dict:
"""Read-only for progress bar metrics."""
+        rank_zero_deprecation(
+            "`trainer.progress_bar_dict` is deprecated in v1.5 and will be removed in v1.7."
+            " Use `ProgressBarBase.get_metrics` instead."
+        )
         ref_model = self.lightning_module
         ref_model = cast(pl.LightningModule, ref_model)

-        standard_metrics = ref_model.get_progress_bar_dict()
-        pbar_metrics = self.progress_bar_metrics
-        duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
-        if duplicates:
-            rank_zero_warn(
-                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
-                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
-                " If this is undesired, change the name or override `get_progress_bar_dict()`"
-                " in `LightingModule`.",
-                UserWarning,
-            )
-        return {**standard_metrics, **pbar_metrics}
+        if self.progress_bar_callback:
+            return self.progress_bar_callback.get_metrics(self, ref_model)
+        return self.progress_bar_metrics

@property
def _should_reload_dl_epoch(self) -> bool:
23 changes: 23 additions & 0 deletions tests/callbacks/test_progress_bar.py
@@ -25,6 +25,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@@ -557,6 +558,28 @@ def _test_progress_bar_max_val_check_interval(
assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches


def test_get_progress_bar_metrics(tmpdir: str):
class TestProgressBar(ProgressBar):
def get_metrics(self, trainer: Trainer, model: LightningModule):
items = super().get_metrics(trainer, model)
items.pop("v_num", None)
return items

progress_bar = TestProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[progress_bar],
fast_dev_run=True,
)
model = BoringModel()
trainer.fit(model)
model.truncated_bptt_steps = 2
standard_metrics = progress_bar.get_metrics(trainer, model)
assert "loss" in standard_metrics.keys()
assert "split_idx" in standard_metrics.keys()
assert "v_num" not in standard_metrics.keys()


def test_progress_bar_main_bar_resume():
"""Test that the progress bar can resume its counters based on the Trainer state."""
bar = ProgressBar()
23 changes: 23 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
@@ -84,6 +84,29 @@ def test_v1_7_0_datamodule_dims_property(tmpdir):
_ = LightningDataModule(dims=(1, 1, 1))


def test_v1_7_0_moved_get_progress_bar_dict(tmpdir):
class TestModel(BoringModel):
def get_progress_bar_dict(self):
items = super().get_progress_bar_dict()
items.pop("v_num", None)
return items

trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=None,
fast_dev_run=True,
)
test_model = TestModel()
with pytest.deprecated_call(match=r"`LightningModule.get_progress_bar_dict` method was deprecated in v1.5"):
trainer.fit(test_model)
standard_metrics_postfix = trainer.progress_bar_callback.main_progress_bar.postfix
assert "loss" in standard_metrics_postfix
assert "v_num" not in standard_metrics_postfix

with pytest.deprecated_call(match=r"`trainer.progress_bar_dict` is deprecated in v1.5"):
_ = trainer.progress_bar_dict


def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
with pytest.deprecated_call(
match="Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0!"
2 changes: 1 addition & 1 deletion tests/models/test_hooks.py
@@ -66,7 +66,7 @@ def training_epoch_end(self, outputs):
trainer = Trainer(max_epochs=num_epochs, default_root_dir=tmpdir, overfit_batches=2)
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"
-    metrics = trainer.progress_bar_dict
+    metrics = trainer.progress_bar_callback.get_metrics(trainer, model)

# metrics added in training step should be unchanged by epoch end method
assert metrics["step_metric"] == -1