Skip to content

Commit 1f7298d

Browse files
akashkwananthsubawaelchli
authored
Deprecate LoggerCollection in favor of trainer.loggers (#12147)
Co-authored-by: ananthsub <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 0db85d6 commit 1f7298d

File tree

8 files changed

+58
-21
lines changed

8 files changed

+58
-21
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
454454
- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))
455455

456456

457+
- Deprecated `LoggerCollection` in favor of `trainer.loggers` ([#12147](https://github.com/PyTorchLightning/pytorch-lightning/pull/12147))
458+
459+
457460
- Deprecated `PrecisionPlugin.on_{save,load}_checkpoint` in favor of `PrecisionPlugin.{state_dict,load_state_dict}` ([#11978](https://github.com/PyTorchLightning/pytorch-lightning/pull/11978))
458461

459462

pytorch_lightning/loggers/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,21 @@ def version(self) -> Union[int, str]:
221221
class LoggerCollection(LightningLoggerBase):
222222
"""The :class:`LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`.
223223
224+
.. deprecated:: v1.6
225+
`LoggerCollection` is deprecated in v1.6 and will be removed in v1.8.
226+
Directly pass a list of loggers to the Trainer and access the list via the `trainer.loggers` attribute.
227+
224228
Args:
225229
logger_iterable: An iterable collection of loggers
226230
"""
227231

228232
def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
229233
super().__init__()
230234
self._logger_iterable = logger_iterable
235+
rank_zero_deprecation(
236+
"`LoggerCollection` is deprecated in v1.6 and will be removed in v1.8. Directly pass a list of loggers"
237+
" to the Trainer and access the list via the `trainer.loggers` attribute."
238+
)
231239

232240
def __getitem__(self, index: int) -> LightningLoggerBase:
233241
return list(self._logger_iterable)[index]

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2623,7 +2623,9 @@ def logger(self) -> Optional[LightningLoggerBase]:
26232623
" This behavior will change in v1.8 when LoggerCollection is removed, and"
26242624
" trainer.logger will return the first logger in trainer.loggers"
26252625
)
2626-
return LoggerCollection(self.loggers)
2626+
with warnings.catch_warnings():
2627+
warnings.simplefilter("ignore")
2628+
return LoggerCollection(self.loggers)
26272629

26282630
@logger.setter
26292631
def logger(self, logger: Optional[LightningLoggerBase]) -> None:

tests/deprecated_api/test_remove_1-8.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torch import optim
2222

2323
from pytorch_lightning import Callback, Trainer
24-
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase
24+
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection
2525
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
2626
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
2727
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
@@ -662,6 +662,23 @@ def _get_python_cprofile_total_duration(profile):
662662
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
663663

664664

665+
def test_v1_8_0_logger_collection(tmpdir):
666+
logger1 = CSVLogger(tmpdir)
667+
logger2 = CSVLogger(tmpdir)
668+
669+
trainer1 = Trainer(logger=logger1)
670+
trainer2 = Trainer(logger=[logger1, logger2])
671+
672+
# Should have no deprecation warning
673+
trainer1.logger
674+
trainer1.loggers
675+
trainer2.loggers
676+
trainer2.logger
677+
678+
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
679+
LoggerCollection([logger1, logger2])
680+
681+
665682
def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
666683
class PrecisionPluginSaveHook(PrecisionPlugin):
667684
def on_save_checkpoint(self, checkpoint):

tests/loggers/test_base.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def test_logger_collection():
3434
mock1 = MagicMock()
3535
mock2 = MagicMock()
3636

37-
logger = LoggerCollection([mock1, mock2])
37+
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
38+
logger = LoggerCollection([mock1, mock2])
3839

3940
assert logger[0] == mock1
4041
assert logger[1] == mock2
@@ -62,14 +63,16 @@ def test_logger_collection_unique_names():
6263
logger1 = CustomLogger(name=unique_name)
6364
logger2 = CustomLogger(name=unique_name)
6465

65-
logger = LoggerCollection([logger1, logger2])
66+
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
67+
logger = LoggerCollection([logger1, logger2])
6668

6769
assert logger.name == unique_name
6870

6971

7072
def test_logger_collection_names_order():
7173
loggers = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")]
72-
logger = LoggerCollection(loggers)
74+
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
75+
logger = LoggerCollection(loggers)
7376
assert logger.name == f"{loggers[0].name}_{loggers[1].name}_{loggers[3].name}"
7477

7578

@@ -78,14 +81,16 @@ def test_logger_collection_unique_versions():
7881
logger1 = CustomLogger(version=unique_version)
7982
logger2 = CustomLogger(version=unique_version)
8083

81-
logger = LoggerCollection([logger1, logger2])
84+
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
85+
logger = LoggerCollection([logger1, logger2])
8286

8387
assert logger.version == unique_version
8488

8589

8690
def test_logger_collection_versions_order():
8791
loggers = [CustomLogger(version=v) for v in ("1", "2", "1", "3")]
88-
logger = LoggerCollection(loggers)
92+
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
93+
logger = LoggerCollection(loggers)
8994
assert logger.version == f"{loggers[0].version}_{loggers[1].version}_{loggers[3].version}"
9095

9196

tests/profiler/test_profiler.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from pytorch_lightning import Callback, Trainer
2626
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
27-
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
27+
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
2828
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
2929
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
3030
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -450,9 +450,9 @@ def test_pytorch_profiler_nested(tmpdir):
450450
assert events_name == expected, (events_name, torch.__version__, platform.system())
451451

452452

453-
def test_pytorch_profiler_logger_collection(tmpdir):
454-
"""Tests whether the PyTorch profiler is able to write its trace locally when the Trainer's logger is an
455-
instance of LoggerCollection.
453+
def test_pytorch_profiler_multiple_loggers(tmpdir):
454+
"""Tests whether the PyTorch profiler is able to write its trace locally when the Trainer is configured with
455+
multiple loggers.
456456
457457
See issue #8157.
458458
"""
@@ -465,10 +465,9 @@ def look_for_trace(trace_dir):
465465
assert not look_for_trace(tmpdir)
466466

467467
model = BoringModel()
468-
# Wrap the logger in a list so it becomes a LoggerCollection
469-
logger = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
470-
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
471-
assert isinstance(trainer.logger, LoggerCollection)
468+
loggers = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
469+
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=loggers, limit_train_batches=5, max_epochs=1)
470+
assert len(trainer.loggers) == 2
472471
trainer.fit(model)
473472
assert look_for_trace(tmpdir)
474473

tests/trainer/properties/test_log_dir.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from pytorch_lightning import Trainer
1717
from pytorch_lightning.callbacks import ModelCheckpoint
18-
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
18+
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
1919
from tests.helpers.boring_model import BoringModel
2020

2121

@@ -109,8 +109,8 @@ def test_logdir_custom_logger(tmpdir):
109109
assert trainer.log_dir == expected
110110

111111

112-
def test_logdir_logger_collection(tmpdir):
113-
"""Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection."""
112+
def test_logdir_multiple_loggers(tmpdir):
113+
"""Tests that the logdir equals the default_root_dir when trainer has multiple loggers."""
114114
default_root_dir = tmpdir / "default_root_dir"
115115
save_dir = tmpdir / "save_dir"
116116
model = TestModel(default_root_dir)
@@ -119,7 +119,6 @@ def test_logdir_logger_collection(tmpdir):
119119
max_steps=2,
120120
logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), CSVLogger(tmpdir)],
121121
)
122-
assert isinstance(trainer.logger, LoggerCollection)
123122
assert trainer.log_dir == default_root_dir
124123

125124
trainer.fit(model)

tests/trainer/properties/test_loggers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import pytest
16+
1517
from pytorch_lightning import Trainer
1618
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
1719
from tests.loggers.test_base import CustomLogger
@@ -50,8 +52,10 @@ def test_trainer_loggers_setters():
5052
"""Test the behavior of setters for trainer.logger and trainer.loggers."""
5153
logger1 = CustomLogger()
5254
logger2 = CustomLogger()
53-
logger_collection = LoggerCollection([logger1, logger2])
54-
logger_collection_2 = LoggerCollection([logger2])
55+
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
56+
logger_collection = LoggerCollection([logger1, logger2])
57+
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
58+
logger_collection_2 = LoggerCollection([logger2])
5559

5660
trainer = Trainer()
5761
assert type(trainer.logger) == TensorBoardLogger

0 commit comments

Comments
 (0)