Skip to content

Commit 0b9034b

Browse files
twslrohitgr7
andauthored
Return only unique names/versions for LoggerCollection (#10976)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 576a5d6 commit 0b9034b

File tree

3 files changed

+48
-8
lines changed

3 files changed

+48
-8
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
112112
- Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934))
113113

114114

115+
- `LoggerCollection` returns only unique logger names and versions ([#10976](https://github.com/PyTorchLightning/pytorch-lightning/pull/10976))
116+
117+
115118
- Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
116119
* All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
117120
* The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts

pytorch_lightning/loggers/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,15 @@ def save_dir(self) -> Optional[str]:
452452

453453
@property
454454
def name(self) -> str:
455-
"""Returns the experiment names for all the loggers in the logger collection joined by an underscore."""
456-
return "_".join(str(logger.name) for logger in self._logger_iterable)
455+
"""Returns the unique experiment names for all the loggers in the logger collection joined by an
456+
underscore."""
457+
return "_".join(dict.fromkeys(str(logger.name) for logger in self._logger_iterable))
457458

458459
@property
459460
def version(self) -> str:
460-
"""Returns the experiment versions for all the loggers in the logger collection joined by an underscore."""
461-
return "_".join(str(logger.version) for logger in self._logger_iterable)
461+
"""Returns the unique experiment versions for all the loggers in the logger collection joined by an
462+
underscore."""
463+
return "_".join(dict.fromkeys(str(logger.version) for logger in self._logger_iterable))
462464

463465

464466
class DummyExperiment:

tests/loggers/test_base.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,52 @@ def test_logger_collection():
5656
mock2.finalize.assert_called_once()
5757

5858

59+
def test_logger_collection_unique_names():
60+
unique_name = "name1"
61+
logger1 = CustomLogger(name=unique_name)
62+
logger2 = CustomLogger(name=unique_name)
63+
64+
logger = LoggerCollection([logger1, logger2])
65+
66+
assert logger.name == unique_name
67+
68+
69+
def test_logger_collection_names_order():
70+
loggers = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")]
71+
logger = LoggerCollection(loggers)
72+
assert logger.name == f"{loggers[0].name}_{loggers[1].name}_{loggers[3].name}"
73+
74+
75+
def test_logger_collection_unique_versions():
76+
unique_version = "1"
77+
logger1 = CustomLogger(version=unique_version)
78+
logger2 = CustomLogger(version=unique_version)
79+
80+
logger = LoggerCollection([logger1, logger2])
81+
82+
assert logger.version == unique_version
83+
84+
85+
def test_logger_collection_versions_order():
86+
loggers = [CustomLogger(version=v) for v in ("1", "2", "1", "3")]
87+
logger = LoggerCollection(loggers)
88+
assert logger.version == f"{loggers[0].version}_{loggers[1].version}_{loggers[3].version}"
89+
90+
5991
class CustomLogger(LightningLoggerBase):
60-
def __init__(self):
92+
def __init__(self, experiment: str = "test", name: str = "name", version: str = "1"):
6193
super().__init__()
94+
self._experiment = experiment
95+
self._name = name
96+
self._version = version
6297
self.hparams_logged = None
6398
self.metrics_logged = {}
6499
self.finalized = False
65100
self.after_save_checkpoint_called = False
66101

67102
@property
68103
def experiment(self):
69-
return "test"
104+
return self._experiment
70105

71106
@rank_zero_only
72107
def log_hyperparams(self, params):
@@ -88,11 +123,11 @@ def save_dir(self) -> Optional[str]:
88123

89124
@property
90125
def name(self):
91-
return "name"
126+
return self._name
92127

93128
@property
94129
def version(self):
95-
return "1"
130+
return self._version
96131

97132
def after_save_checkpoint(self, checkpoint_callback):
98133
self.after_save_checkpoint_called = True

0 commit comments

Comments
 (0)