Skip to content

Commit 13a9112

Browse files
akashkwpre-commit-ci[bot]awaelchli
authored andcommitted
Deprecate and remove calls to agg_and_log_metrics (#11832)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent c6aae7a commit 13a9112

File tree

5 files changed

+78
-35
lines changed

5 files changed

+78
-35
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
397397
- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning`
398398

399399

400+
- Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832))
401+
402+
400403
### Removed
401404

402405
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

pytorch_lightning/loggers/base.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def _aggregate_metrics(
104104
) -> Tuple[int, Optional[Dict[str, float]]]:
105105
"""Aggregates metrics.
106106
107+
.. deprecated:: v1.6
108+
This method is deprecated in v1.6 and will be removed in v1.8.
109+
107110
Args:
108111
metrics: Dictionary with metric names as keys and measured quantities as values
109112
step: Step number at which the metrics should be recorded
@@ -126,7 +129,13 @@ def _aggregate_metrics(
126129
return agg_step, agg_mets
127130

128131
def _reduce_agg_metrics(self):
129-
"""Aggregate accumulated metrics."""
132+
"""Aggregate accumulated metrics.
133+
134+
See deprecation warning below.
135+
136+
.. deprecated:: v1.6
137+
This method is deprecated in v1.6 and will be removed in v1.8.
138+
"""
130139
# compute the metrics
131140
if not self._metrics_to_agg:
132141
agg_mets = None
@@ -137,7 +146,13 @@ def _reduce_agg_metrics(self):
137146
return self._prev_step, agg_mets
138147

139148
def _finalize_agg_metrics(self):
140-
"""This shall be called before save/close."""
149+
"""This shall be called before save/close.
150+
151+
See deprecation warning below.
152+
153+
.. deprecated:: v1.6
154+
This method is deprecated in v1.6 and will be removed in v1.8.
155+
"""
141156
agg_step, metrics_to_log = self._reduce_agg_metrics()
142157
self._metrics_to_agg = []
143158

@@ -148,6 +163,10 @@ def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = N
148163
"""Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead
149164
it aggregates them and logs only if metrics are ready to be logged.
150165
166+
.. deprecated:: v1.6
167+
This method is deprecated in v1.6 and will be removed in v1.8.
168+
Please use `LightningLoggerBase.log_metrics` instead.
169+
151170
Args:
152171
metrics: Dictionary with metric names as keys and measured quantities as values
153172
step: Step number at which the metrics should be recorded
@@ -272,11 +291,11 @@ def experiment(self) -> List[Any]:
272291

273292
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
274293
for logger in self._logger_iterable:
275-
logger.agg_and_log_metrics(metrics, step)
294+
logger.agg_and_log_metrics(metrics=metrics, step=step)
276295

277296
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
278297
for logger in self._logger_iterable:
279-
logger.log_metrics(metrics, step)
298+
logger.log_metrics(metrics=metrics, step=step)
280299

281300
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
282301
for logger in self._logger_iterable:

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from pytorch_lightning.utilities import _AcceleratorType, memory
2424
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
2525
from pytorch_lightning.utilities.metrics import metrics_to_scalars
26-
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
26+
from pytorch_lightning.utilities.model_helpers import is_overridden
27+
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
2728

2829

2930
class LoggerConnector:
@@ -45,6 +46,7 @@ def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None)
4546
self._current_fx: Optional[str] = None
4647
self._batch_idx: Optional[int] = None
4748
self._split_idx: Optional[int] = None
49+
self._override_agg_and_log_metrics: bool = False
4850

4951
def on_trainer_init(
5052
self,
@@ -64,6 +66,15 @@ def on_trainer_init(
6466
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
6567
self.trainer.log_every_n_steps = log_every_n_steps
6668
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
69+
for logger in self.trainer.loggers:
70+
if is_overridden("agg_and_log_metrics", logger, LightningLoggerBase):
71+
self._override_agg_and_log_metrics = True
72+
rank_zero_deprecation(
73+
"`LightningLoggerBase.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
74+
" in v1.8. `Trainer` will directly call `LightningLoggerBase.log_metrics` so custom"
75+
" loggers should not implement `LightningLoggerBase.agg_and_log_metrics`."
76+
)
77+
break
6778

6879
@property
6980
def should_flush_logs(self) -> bool:
@@ -114,7 +125,10 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
114125
step = self.trainer.global_step
115126

116127
# log actual metrics
117-
self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
128+
if self._override_agg_and_log_metrics:
129+
self.trainer.logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
130+
else:
131+
self.trainer.logger.log_metrics(metrics=scalar_metrics, step=step)
118132
self.trainer.logger.save()
119133

120134
"""

tests/deprecated_api/test_remove_1-8.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torch import optim
2020

2121
from pytorch_lightning import Callback, Trainer
22+
from pytorch_lightning.loggers import CSVLogger
2223
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
2324
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
2425
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
@@ -35,7 +36,7 @@
3536
from pytorch_lightning.utilities.apply_func import move_data_to_device
3637
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
3738
from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
38-
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
39+
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
3940
from tests.helpers.boring_model import BoringDataModule, BoringModel
4041
from tests.helpers.runif import RunIf
4142
from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
@@ -500,3 +501,34 @@ def on_before_accelerator_backend_setup(self, *args, **kwargs):
500501
" and will be removed in v1.8"
501502
):
502503
trainer.fit(model)
504+
505+
506+
def test_v1_8_0_deprecated_agg_and_log_metrics_override(tmpdir):
507+
class AggregationOverrideLogger(CSVLogger):
508+
@rank_zero_only
509+
def agg_and_log_metrics(self, metrics, step):
510+
self.log_metrics(metrics=metrics, step=step)
511+
512+
logger = AggregationOverrideLogger(tmpdir)
513+
logger2 = CSVLogger(tmpdir)
514+
logger3 = CSVLogger(tmpdir)
515+
516+
# Test single loggers
517+
with pytest.deprecated_call(
518+
match="`LightningLoggerBase.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
519+
" in v1.8. `Trainer` will directly call `LightningLoggerBase.log_metrics` so custom"
520+
" loggers should not implement `LightningLoggerBase.agg_and_log_metrics`."
521+
):
522+
Trainer(logger=logger)
523+
# Should have no deprecation warning
524+
Trainer(logger=logger2)
525+
526+
# Test multiple loggers
527+
with pytest.deprecated_call(
528+
match="`LightningLoggerBase.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
529+
" in v1.8. `Trainer` will directly call `LightningLoggerBase.log_metrics` so custom"
530+
" loggers should not implement `LightningLoggerBase.agg_and_log_metrics`."
531+
):
532+
Trainer(logger=[logger, logger3])
533+
# Should have no deprecation warning
534+
Trainer(logger=[logger2, logger3])

tests/loggers/test_base.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ def test_logger_collection():
4848
mock1.update_agg_funcs.assert_called_once_with({"test": np.mean}, np.sum)
4949
mock2.update_agg_funcs.assert_called_once_with({"test": np.mean}, np.sum)
5050

51-
logger.agg_and_log_metrics({"test": 2.0}, 4)
52-
mock1.agg_and_log_metrics.assert_called_once_with({"test": 2.0}, 4)
53-
mock2.agg_and_log_metrics.assert_called_once_with({"test": 2.0}, 4)
51+
logger.log_metrics(metrics={"test": 2.0}, step=4)
52+
mock1.log_metrics.assert_called_once_with(metrics={"test": 2.0}, step=4)
53+
mock2.log_metrics.assert_called_once_with(metrics={"test": 2.0}, step=4)
5454

5555
logger.finalize("success")
5656
mock1.finalize.assert_called_once()
@@ -225,31 +225,6 @@ def validation_epoch_end(self, outputs):
225225
trainer.fit(model)
226226

227227

228-
def test_with_accumulate_grad_batches():
229-
"""Checks if the logging is performed once for `accumulate_grad_batches` steps."""
230-
231-
class StoreHistoryLogger(CustomLogger):
232-
def __init__(self):
233-
super().__init__()
234-
self.history = {}
235-
236-
@rank_zero_only
237-
def log_metrics(self, metrics, step):
238-
if step not in self.history:
239-
self.history[step] = {}
240-
self.history[step].update(metrics)
241-
242-
logger = StoreHistoryLogger()
243-
244-
np.random.seed(42)
245-
for i, loss in enumerate(np.random.random(10)):
246-
logger.agg_and_log_metrics({"loss": loss}, step=int(i / 5))
247-
248-
assert logger.history == {0: {"loss": 0.5623850983416314}}
249-
logger.save()
250-
assert logger.history == {0: {"loss": 0.5623850983416314}, 1: {"loss": 0.4778883735637184}}
251-
252-
253228
def test_dummyexperiment_support_indexing():
254229
"""Test that the DummyExperiment can imitate indexing the experiment in a LoggerCollection."""
255230
experiment = DummyExperiment()

0 commit comments

Comments
 (0)