
Commit a2d8c4f

akashkw, ananthsub, carmocca, and daniellepintz authored
Create loggers property for Trainer and LightningModule (#11683)
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Danielle Pintz <[email protected]>
1 parent 1e36cff commit a2d8c4f
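
In short, this commit exposes every configured logger as a plain list on both `Trainer` and `LightningModule`. A minimal usage sketch of the new API (the particular logger choices here are illustrative, not part of the commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

    # Any combination of supported loggers works; these two are just examples.
    trainer = Trainer(logger=[TensorBoardLogger("logs/"), CSVLogger("logs/")])

    # New in this commit: trainer.loggers returns the configured loggers as a list.
    for logger in trainer.loggers:
        logger.log_metrics({"acc": 0.9}, step=0)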

File tree

9 files changed: +183 -10 lines

CHANGELOG.md

Lines changed: 7 additions & 0 deletions

@@ -86,6 +86,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a `MisconfigurationException` if user provided `opt_idx` in scheduler config doesn't match with actual optimizer index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))
 
+
+- Added a `loggers` property to `Trainer` which returns a list of loggers provided by the user ([#11683](https://github.com/PyTorchLightning/pytorch-lightning/pull/11683))
+
+
+- Added a `loggers` property to `LightningModule` which retrieves the `loggers` property from `Trainer` ([#11683](https://github.com/PyTorchLightning/pytorch-lightning/pull/11683))
+
+
 - Added support for DDP when using a `CombinedLoader` for the training data ([#11648](https://github.com/PyTorchLightning/pytorch-lightning/pull/11648))

docs/source/common/lightning_module.rst

Lines changed: 13 additions & 0 deletions

@@ -985,6 +985,19 @@ The current logger being used (tensorboard or other supported logger)
 
         # the particular logger
         tensorboard_logger = self.logger.experiment
 
+loggers
+~~~~~~~
+
+The list of loggers currently being used by the Trainer.
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx):
+        # List of LightningLoggerBase objects
+        loggers = self.loggers
+        for logger in loggers:
+            logger.log_metrics({"foo": 1.0})
+
 local_rank
 ~~~~~~~~~~

docs/source/common/trainer.rst

Lines changed: 17 additions & 5 deletions

@@ -1734,16 +1734,28 @@ The current epoch
         pass
 
-logger (p)
-**********
+logger
+*******
 
 The current logger being used. Here's an example using tensorboard
 
 .. code-block:: python
 
-    def training_step(self, batch, batch_idx):
-        logger = self.trainer.logger
-        tensorboard = logger.experiment
+    logger = trainer.logger
+    tensorboard = logger.experiment
+
+
+loggers
+********
+
+The list of loggers currently being used by the Trainer.
+
+.. code-block:: python
+
+    # List of LightningLoggerBase objects
+    loggers = trainer.loggers
+    for logger in loggers:
+        logger.log_metrics({"foo": 1.0})
 
 
 logged_metrics

pytorch_lightning/core/lightning.py

Lines changed: 8 additions & 1 deletion

@@ -37,6 +37,7 @@
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.saving import ModelIO
+from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
@@ -75,6 +76,7 @@ class LightningModule(
         "global_rank",
         "local_rank",
         "logger",
+        "loggers",
         "model_size",
         "automatic_optimization",
         "truncated_bptt_steps",
@@ -247,10 +249,15 @@ def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
         self._truncated_bptt_steps = truncated_bptt_steps
 
     @property
-    def logger(self):
+    def logger(self) -> Optional[LightningLoggerBase]:
         """Reference to the logger object in the Trainer."""
         return self.trainer.logger if self.trainer else None
 
+    @property
+    def loggers(self) -> List[LightningLoggerBase]:
+        """Reference to the loggers object in the Trainer."""
+        return self.trainer.loggers if self.trainer else []
+
     def _apply_batch_transfer_handler(
         self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
     ) -> Any:
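
`LightningModule.loggers` simply forwards to the attached Trainer and falls back to an empty list, so module code can always iterate over it. A short sketch under that assumption (the module name and placeholder loss are hypothetical, for illustration only):

    import torch
    import pytorch_lightning as pl

    class DemoModule(pl.LightningModule):  # hypothetical module, not from the commit
        def training_step(self, batch, batch_idx):
            loss = torch.tensor(0.0, requires_grad=True)  # placeholder loss
            # self.loggers is [] when no Trainer is attached, so this loop is always safe
            for logger in self.loggers:
                logger.log_metrics({"train_loss": loss.item()}, step=self.global_step)
            return loss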

pytorch_lightning/trainer/trainer.py

Lines changed: 32 additions & 1 deletion

@@ -565,7 +565,7 @@ def __init__(
         self.__init_profiler(profiler)
 
         # init logger flags
-        self.logger: Optional[LightningLoggerBase]
+        self._loggers: List[LightningLoggerBase]
         self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)
 
         # init debugging flags
@@ -2553,6 +2553,37 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop]]:
     Logging properties
     """
 
+    @property
+    def logger(self) -> Optional[LightningLoggerBase]:
+        if len(self.loggers) == 0:
+            return None
+        if len(self.loggers) == 1:
+            return self.loggers[0]
+        else:
+            rank_zero_warn(
+                "Using trainer.logger when Trainer is configured to use multiple loggers."
+                " This behavior will change in v1.8 when LoggerCollection is removed, and"
+                " trainer.logger will return the first logger in trainer.loggers"
+            )
+            return LoggerCollection(self.loggers)
+
+    @logger.setter
+    def logger(self, logger: Optional[LightningLoggerBase]) -> None:
+        if not logger:
+            self.loggers = []
+        elif isinstance(logger, LoggerCollection):
+            self.loggers = list(logger)
+        else:
+            self.loggers = [logger]
+
+    @property
+    def loggers(self) -> List[LightningLoggerBase]:
+        return self._loggers
+
+    @loggers.setter
+    def loggers(self, loggers: Optional[List[LightningLoggerBase]]) -> None:
+        self._loggers = loggers if loggers else []
+
     @property
     def callback_metrics(self) -> dict:
         return self.logger_connector.callback_metrics
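
Taken together, the getter and setters keep `trainer.logger` and `trainer.loggers` in sync: a single logger round-trips through a one-element list, multiple loggers make `trainer.logger` warn and return a `LoggerCollection`, and clearing either side clears both. A sketch of that behavior, using `DummyLogger` as a stand-in for any `LightningLoggerBase`:

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers.base import DummyLogger

    logger1, logger2 = DummyLogger(), DummyLogger()

    trainer = Trainer(logger=logger1)
    assert trainer.loggers == [logger1]   # a single logger is wrapped in a list

    trainer.loggers = [logger1, logger2]
    collection = trainer.logger           # warns, returns LoggerCollection of both

    trainer.logger = None
    assert trainer.loggers == []          # clearing one side clears the other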

tests/core/test_lightning_module.py

Lines changed: 11 additions & 0 deletions

@@ -76,6 +76,17 @@ def test_property_logger(tmpdir):
     assert model.logger == logger
 
 
+def test_property_loggers(tmpdir):
+    """Test that loggers in LightningModule is accessible via the Trainer."""
+    model = BoringModel()
+    assert model.loggers == []
+
+    logger = TensorBoardLogger(tmpdir)
+    trainer = Trainer(logger=logger)
+    model.trainer = trainer
+    assert model.loggers == [logger]
+
+
 def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir):
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx, optimizer_idx=None):

tests/profiler/test_profiler.py

Lines changed: 2 additions & 2 deletions

@@ -24,7 +24,7 @@
 
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
-from pytorch_lightning.loggers.base import LoggerCollection
+from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
@@ -493,7 +493,7 @@ def look_for_trace(trace_dir):
 
     model = BoringModel()
     # Wrap the logger in a list so it becomes a LoggerCollection
-    logger = [TensorBoardLogger(save_dir=tmpdir)]
+    logger = [TensorBoardLogger(save_dir=tmpdir), DummyLogger()]
     trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
     assert isinstance(trainer.logger, LoggerCollection)
     trainer.fit(model)

tests/trainer/properties/test_log_dir.py

Lines changed: 2 additions & 1 deletion

@@ -16,6 +16,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
+from pytorch_lightning.loggers.base import DummyLogger
 from tests.helpers.boring_model import BoringModel
 
 
@@ -117,7 +118,7 @@ def test_logdir_logger_collection(tmpdir):
     trainer = Trainer(
         default_root_dir=default_root_dir,
         max_steps=2,
-        logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs")],
+        logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), DummyLogger()],
     )
     assert isinstance(trainer.logger, LoggerCollection)
     assert trainer.log_dir == default_root_dir
tests/trainer/properties/test_loggers.py

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
+from tests.loggers.test_base import CustomLogger
+
+
+def test_trainer_loggers_property():
+    """Test for correct initialization of loggers in Trainer."""
+    logger1 = CustomLogger()
+    logger2 = CustomLogger()
+
+    # trainer.loggers should be a copy of the input list
+    trainer = Trainer(logger=[logger1, logger2])
+    assert trainer.loggers == [logger1, logger2]
+
+    # trainer.loggers should create a list of size 1
+    trainer = Trainer(logger=logger1)
+    assert trainer.loggers == [logger1]
+
+    # trainer.loggers should be an empty list
+    trainer = Trainer(logger=False)
+    assert trainer.loggers == []
+
+    # trainer.loggers should be a list of size 1 holding the default logger
+    trainer = Trainer(logger=True)
+    assert trainer.loggers == [trainer.logger]
+    assert type(trainer.loggers[0]) == TensorBoardLogger
+
+
+def test_trainer_loggers_setters():
+    """Test the behavior of setters for trainer.logger and trainer.loggers."""
+    logger1 = CustomLogger()
+    logger2 = CustomLogger()
+    logger_collection = LoggerCollection([logger1, logger2])
+    logger_collection_2 = LoggerCollection([logger2])
+
+    trainer = Trainer()
+    assert type(trainer.logger) == TensorBoardLogger
+    assert trainer.loggers == [trainer.logger]
+
+    # Test setters for trainer.logger
+    trainer.logger = logger1
+    assert trainer.logger == logger1
+    assert trainer.loggers == [logger1]
+
+    trainer.logger = logger_collection
+    assert trainer.logger._logger_iterable == logger_collection._logger_iterable
+    assert trainer.loggers == [logger1, logger2]
+
+    # LoggerCollection of size 1 should result in trainer.logger becoming the contained logger.
+    trainer.logger = logger_collection_2
+    assert trainer.logger == logger2
+    assert trainer.loggers == [logger2]
+
+    trainer.logger = None
+    assert trainer.logger is None
+    assert trainer.loggers == []
+
+    # Test setters for trainer.loggers
+    trainer.loggers = [logger1, logger2]
+    assert trainer.loggers == [logger1, logger2]
+    assert trainer.logger._logger_iterable == logger_collection._logger_iterable
+
+    trainer.loggers = [logger1]
+    assert trainer.loggers == [logger1]
+    assert trainer.logger == logger1
+
+    trainer.loggers = []
+    assert trainer.loggers == []
+    assert trainer.logger is None
+
+    trainer.loggers = None
+    assert trainer.loggers == []
+    assert trainer.logger is None