Remove the deprecated agg_and_log_metrics #14840


Merged · 8 commits · Sep 22, 2022
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -203,6 +203,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated device attributes `Trainer.{devices,gpus,num_gpus,ipus,tpu_cores}` in favor of the accelerator-agnostic `Trainer.num_devices` ([#14829](https://github.com/Lightning-AI/lightning/pull/14829))


- Removed the deprecated `Logger.agg_and_log_metrics` hook in favour of `Logger.log_metrics` and the `agg_key_funcs` and `agg_default_func` arguments. ([#14840](https://github.com/Lightning-AI/lightning/pull/14840))


- Removed the deprecated precision plugin checkpoint hooks `PrecisionPlugin.on_load_checkpoint` and `PrecisionPlugin.on_save_checkpoint` ([#14833](https://github.com/Lightning-AI/lightning/pull/14833))


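Since `agg_and_log_metrics` is gone, custom loggers that aggregated metrics across calls must now do that inside `log_metrics` itself. A minimal migration sketch — the `MeanAggregatingLogger` class and its print-based backend are illustrative, not part of this PR:

```python
from collections import defaultdict
from typing import Any, Dict, List, Optional

import numpy as np

from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities.rank_zero import rank_zero_only


class MeanAggregatingLogger(Logger):
    """Illustrative only: buffers values per step and logs their mean."""

    def __init__(self) -> None:
        super().__init__()
        self._buffer: Dict[Optional[int], Dict[str, List[float]]] = defaultdict(
            lambda: defaultdict(list)
        )

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # Aggregation that previously lived behind `agg_and_log_metrics`
        # now happens directly in `log_metrics`.
        for name, value in metrics.items():
            self._buffer[step][name].append(value)
        aggregated = {k: float(np.mean(v)) for k, v in self._buffer[step].items()}
        print(f"step={step}: {aggregated}")  # stand-in for a real backend call

    @rank_zero_only
    def log_hyperparams(self, params: Any) -> None:
        pass

    @property
    def name(self) -> str:
        return "mean_aggregating"

    @property
    def version(self) -> str:
        return "0"
```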
20 changes: 1 addition & 19 deletions src/pytorch_lightning/loggers/base.py
@@ -29,25 +29,7 @@ def rank_zero_experiment(fn: Callable) -> Callable:


class LightningLoggerBase(logger.Logger):
"""Base class for experiment loggers.

Args:
agg_key_funcs:
Dictionary which maps a metric name to a function, which will
aggregate the metric values for the same steps.
agg_default_func:
Default function to aggregate metric values. If some metric name
is not presented in the `agg_key_funcs` dictionary, then the
`agg_default_func` will be used for aggregation.

.. deprecated:: v1.6
The parameters `agg_key_funcs` and `agg_default_func` are deprecated
in v1.6 and will be removed in v1.8.

Note:
The `agg_key_funcs` and `agg_default_func` arguments are used only when
one logs metrics with the :meth:`~LightningLoggerBase.agg_and_log_metrics` method.
"""
"""Base class for experiment loggers."""

def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
rank_zero_deprecation(
6 changes: 2 additions & 4 deletions src/pytorch_lightning/loggers/comet.py
@@ -19,7 +19,7 @@
import logging
import os
from argparse import Namespace
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
from typing import Any, Dict, Mapping, Optional, Union

from lightning_utilities.core.imports import module_available
from torch import Tensor
@@ -219,15 +219,13 @@ def __init__(
experiment_key: Optional[str] = None,
offline: bool = False,
prefix: str = "",
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
**kwargs: Any,
):
if comet_ml is None:
raise ModuleNotFoundError(
"You want to use `comet_ml` logger which is not installed yet, install it with `pip install comet-ml`."
)
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
super().__init__()
self._experiment = None
self._save_dir: Optional[str]
self.rest_api_key: Optional[str]
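The constructor change above is identical across the Comet, Neptune, TensorBoard, and W&B loggers in this PR: the two aggregation parameters are dropped and a bare `super().__init__()` call remains. A hedged before/after sketch (argument values are placeholders):

```python
from pytorch_lightning.loggers import CometLogger

# Before this PR (deprecated since v1.6):
# CometLogger(save_dir="logs", agg_default_func=np.mean)

# After this PR, only the logger-specific arguments remain; any
# aggregation has to be implemented in a custom `log_metrics` instead.
logger = CometLogger(save_dir="logs", project_name="demo", offline=True)
```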
94 changes: 6 additions & 88 deletions src/pytorch_lightning/loggers/logger.py
@@ -20,15 +20,15 @@
from argparse import Namespace
from collections import defaultdict
from functools import wraps
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
from weakref import ReferenceType

import numpy as np
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only


def rank_zero_experiment(fn: Callable) -> Callable:
@@ -57,47 +57,7 @@ def get_experiment() -> Callable:


class Logger(ABC):
"""Base class for experiment loggers.

Args:
agg_key_funcs:
Dictionary which maps a metric name to a function, which will
aggregate the metric values for the same steps.
agg_default_func:
Default function to aggregate metric values. If some metric name
is not presented in the `agg_key_funcs` dictionary, then the
`agg_default_func` will be used for aggregation.

.. deprecated:: v1.6
The parameters `agg_key_funcs` and `agg_default_func` are deprecated
in v1.6 and will be removed in v1.8.

Note:
The `agg_key_funcs` and `agg_default_func` arguments are used only when
one logs metrics with the :meth:`~Logger.agg_and_log_metrics` method.
"""

def __init__(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
):
self._prev_step: int = -1
self._metrics_to_agg: List[Dict[str, float]] = []
if agg_key_funcs:
self._agg_key_funcs = agg_key_funcs
rank_zero_deprecation(
"The `agg_key_funcs` parameter for `Logger` was deprecated in v1.6" " and will be removed in v1.8."
)
else:
self._agg_key_funcs = {}
if agg_default_func:
self._agg_default_func = agg_default_func
rank_zero_deprecation(
"The `agg_default_func` parameter for `Logger` was deprecated in v1.6" " and will be removed in v1.8."
)
else:
self._agg_default_func = np.mean
"""Base class for experiment loggers."""

def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
"""Called after model checkpoint callback saves a new checkpoint.
@@ -107,52 +67,9 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]"
"""
pass

def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Callable[[Sequence[float]], float] = np.mean,
) -> None:
"""Update aggregation methods.

.. deprecated:: v1.6
`update_agg_funcs` is deprecated in v1.6 and will be removed in v1.8.

Args:
agg_key_funcs:
Dictionary which maps a metric name to a function, which will
aggregate the metric values for the same steps.
agg_default_func:
Default function to aggregate metric values. If some metric name
is not presented in the `agg_key_funcs` dictionary, then the
`agg_default_func` will be used for aggregation.
"""
if agg_key_funcs:
self._agg_key_funcs.update(agg_key_funcs)
if agg_default_func:
self._agg_default_func = agg_default_func
rank_zero_deprecation("`Logger.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8.")

def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
"""Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead
it aggregates them and logs only if metrics are ready to be logged.

.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
Please use `Logger.log_metrics` instead.

Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
self.log_metrics(metrics=metrics, step=step)

@abstractmethod
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
"""
Records metrics.
This method logs metrics as soon as it received them. If you want to aggregate
metrics for one specific `step`, use the
:meth:`~pytorch_lightning.loggers.base.Logger.agg_and_log_metrics` method.
"""Records metrics. This method logs metrics as soon as it received them.

Args:
metrics: Dictionary with metric names as keys and measured quantities as values
@@ -273,7 +190,8 @@ def method(*args: Any, **kwargs: Any) -> None:
return method


def merge_dicts(
# TODO: this should have been deprecated
def merge_dicts( # pragma: no cover
dicts: Sequence[Mapping],
agg_key_funcs: Optional[Mapping] = None,
default_func: Callable[[Sequence[float]], float] = np.mean,
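`merge_dicts` stays behind (now flagged with a TODO and excluded from coverage). A small usage sketch, assuming only the signature shown above — a sequence of dicts, optional per-key aggregation functions, and `np.mean` as the default:

```python
import numpy as np

from pytorch_lightning.loggers.logger import merge_dicts

steps = [
    {"loss": 1.0, "acc": 0.50},
    {"loss": 0.5, "acc": 0.75},
]

# "acc" is aggregated with np.max; every other key falls back to np.mean.
merged = merge_dicts(steps, agg_key_funcs={"acc": np.max})
print(merged)  # expected: {'loss': 0.75, 'acc': 0.75}
```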
6 changes: 2 additions & 4 deletions src/pytorch_lightning/loggers/neptune.py
@@ -22,7 +22,7 @@
import logging
import os
from argparse import Namespace
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Sequence, Set, Union
from typing import Any, Dict, Generator, List, Optional, Set, Union
from weakref import ReferenceType

from lightning_utilities.core.imports import RequirementCache
@@ -227,15 +227,13 @@ def __init__(
run: Optional["Run"] = None,
log_model_checkpoints: Optional[bool] = True,
prefix: str = "training",
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
**neptune_run_kwargs: Any,
):
if not _NEPTUNE_AVAILABLE:
raise ModuleNotFoundError(str(_NEPTUNE_AVAILABLE))
# verify if user passed proper init arguments
self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
super().__init__()
self._log_model_checkpoints = log_model_checkpoints
self._prefix = prefix
self._run_name = name
6 changes: 2 additions & 4 deletions src/pytorch_lightning/loggers/tensorboard.py
@@ -19,7 +19,7 @@
import logging
import os
from argparse import Namespace
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
from typing import Any, Dict, Mapping, Optional, Union

import numpy as np
from torch import Tensor
@@ -94,11 +94,9 @@ def __init__(
default_hp_metric: bool = True,
prefix: str = "",
sub_dir: Optional[str] = None,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
**kwargs: Any,
):
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
super().__init__()
self._save_dir = save_dir
self._name = name or ""
self._version = version
6 changes: 2 additions & 4 deletions src/pytorch_lightning/loggers/wandb.py
@@ -18,7 +18,7 @@
import os
from argparse import Namespace
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
from typing import Any, Dict, List, Mapping, Optional, Union
from weakref import ReferenceType

import torch.nn as nn
@@ -294,8 +294,6 @@ def __init__(
log_model: Union[str, bool] = False,
experiment: Union[Run, RunDisabled, None] = None,
prefix: str = "",
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
**kwargs: Any,
) -> None:
if wandb is None:
@@ -318,7 +316,7 @@
"Hint: Upgrade with `pip install --upgrade wandb`."
)

super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
super().__init__()
self._offline = offline
self._log_model = log_model
self._prefix = prefix
src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -22,8 +22,6 @@
from pytorch_lightning.loggers import Logger, TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class LoggerConnector:
@@ -36,7 +34,6 @@ def __init__(self, trainer: "pl.Trainer") -> None:
self._current_fx: Optional[str] = None
self._batch_idx: Optional[int] = None
self._split_idx: Optional[int] = None
self._override_agg_and_log_metrics: bool = False

def on_trainer_init(
self,
@@ -47,15 +44,6 @@ def on_trainer_init(
self.configure_logger(logger)
self.trainer.log_every_n_steps = log_every_n_steps
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
for logger in self.trainer.loggers:
if is_overridden("agg_and_log_metrics", logger, Logger):
self._override_agg_and_log_metrics = True
rank_zero_deprecation(
"`Logger.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
" in v1.8. `Trainer` will directly call `Logger.log_metrics` so custom"
" loggers should not implement `Logger.agg_and_log_metrics`."
)
break

@property
def should_update_logs(self) -> bool:
@@ -104,10 +92,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:

# log actual metrics
for logger in self.trainer.loggers:
if self._override_agg_and_log_metrics:
logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
else:
logger.log_metrics(metrics=scalar_metrics, step=step)
logger.log_metrics(metrics=scalar_metrics, step=step)
logger.save()

"""
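One behavioral consequence of the simplified dispatch loop above: a custom logger that still overrides the removed hook is now silently ignored rather than warned about. A sketch of the situation (the `LegacyLogger` class is illustrative):

```python
from pytorch_lightning.loggers import CSVLogger


class LegacyLogger(CSVLogger):
    # The Trainer now calls `log_metrics` directly, so this override is
    # dead code after the PR: no deprecation warning, no dispatch to it.
    def agg_and_log_metrics(self, metrics, step):
        self.log_metrics(metrics=metrics, step=step)
```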
68 changes: 0 additions & 68 deletions tests/tests_pytorch/deprecated_api/test_remove_1-8.py
@@ -22,12 +22,10 @@
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
from pytorch_lightning.loggers import CSVLogger, Logger
from pytorch_lightning.profilers import AdvancedProfiler, SimpleProfiler
from pytorch_lightning.strategies.ipu import LightningIPUModule
from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.rank_zero import rank_zero_only


def test_v1_8_0_on_init_start_end(tmpdir):
@@ -289,72 +287,6 @@ def on_before_accelerator_backend_setup(self, *args, **kwargs):
trainer.fit(model)


def test_v1_8_0_logger_agg_parameters():
class CustomLogger(Logger):
@rank_zero_only
def log_hyperparams(self, params):
pass

@rank_zero_only
def log_metrics(self, metrics, step):
pass

@property
def name(self):
pass

@property
def version(self):
pass

with pytest.deprecated_call(
match="The `agg_key_funcs` parameter for `Logger` was deprecated in v1.6" " and will be removed in v1.8."
):
CustomLogger(agg_key_funcs={"mean", np.mean})

with pytest.deprecated_call(
match="The `agg_default_func` parameter for `Logger` was deprecated in v1.6" " and will be removed in v1.8."
):
CustomLogger(agg_default_func=np.mean)

# Should have no deprecation warning
logger = CustomLogger()

with pytest.deprecated_call(match="`Logger.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8."):
logger.update_agg_funcs()


def test_v1_8_0_deprecated_agg_and_log_metrics_override(tmpdir):
class AggregationOverrideLogger(CSVLogger):
@rank_zero_only
def agg_and_log_metrics(self, metrics, step):
self.log_metrics(metrics=metrics, step=step)

logger = AggregationOverrideLogger(tmpdir)
logger2 = CSVLogger(tmpdir)
logger3 = CSVLogger(tmpdir)

# Test single loggers
with pytest.deprecated_call(
match="`Logger.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
" in v1.8. `Trainer` will directly call `Logger.log_metrics` so custom"
" loggers should not implement `Logger.agg_and_log_metrics`."
):
Trainer(logger=logger)
# Should have no deprecation warning
Trainer(logger=logger2)

# Test multiple loggers
with pytest.deprecated_call(
match="`Logger.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
" in v1.8. `Trainer` will directly call `Logger.log_metrics` so custom"
" loggers should not implement `Logger.agg_and_log_metrics`."
):
Trainer(logger=[logger, logger3])
# Should have no deprecation warning
Trainer(logger=[logger2, logger3])


def test_v1_8_0_callback_on_pretrain_routine_start_end(tmpdir):
class TestCallback(Callback):
def on_pretrain_routine_start(self, trainer, pl_module):
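With the two deprecation tests deleted, a sanity check of the new behavior might look like this sketch (not part of the PR), asserting that plain logger construction no longer emits an aggregation-related deprecation warning:

```python
import warnings

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger


def test_no_agg_deprecation_warning(tmpdir):
    # Capture every warning raised during Trainer construction and make
    # sure none of them mentions the removed aggregation hook.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        Trainer(logger=CSVLogger(tmpdir))
    assert not any("agg_and_log_metrics" in str(w.message) for w in record)
```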