Skip to content

Create comet_experiment property to replace experiment in CometLogger #11570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,14 @@ def __init__(

@property
@rank_zero_experiment
def experiment(self):
def comet_experiment(self) -> Union[CometExperiment, CometExistingExperiment, CometOfflineExperiment]:
r"""
Actual Comet object. To use Comet features in your
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

Example::

self.logger.experiment.some_comet_function()
self.logger.comet_experiment.some_comet_function()

"""
if self._experiment is not None:
Expand Down Expand Up @@ -230,11 +230,26 @@ def experiment(self):

return self._experiment

@property
@rank_zero_experiment
def experiment(self) -> Union[CometExperiment, CometExistingExperiment, CometOfflineExperiment]:
r"""

Actual Comet object. To use Comet features in your
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

Example::

self.logger.experiment.some_comet_function()

"""
return self.comet_experiment

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
params = self._flatten_dict(params)
self.experiment.log_parameters(params)
self.comet_experiment.log_parameters(params)

@rank_zero_only
def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None) -> None:
Expand All @@ -247,23 +262,23 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti

epoch = metrics_without_epoch.pop("epoch", None)
metrics_without_epoch = self._add_prefix(metrics_without_epoch)
self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch)
self.comet_experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch)

def reset_experiment(self):
self._experiment = None

@rank_zero_only
def finalize(self, status: str) -> None:
r"""
When calling ``self.experiment.end()``, that experiment won't log any more data to Comet.
When calling ``self.comet_experiment.end()``, that experiment won't log any more data to Comet.
That's why, if you need to log any more data, you need to create an ExistingCometExperiment.
For example, to log data when testing your model after training, because when training is
finalized :meth:`CometLogger.finalize` is called.

This happens automatically in the :meth:`~CometLogger.experiment` property, when
This happens automatically in the :meth:`~CometLogger.comet_experiment` property, when
``self._experiment`` is set to ``None``, i.e. ``self.reset_experiment()``.
"""
self.experiment.end()
self.comet_experiment.end()
self.reset_experiment()

@property
Expand Down
20 changes: 10 additions & 10 deletions tests/loggers/test_comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ def test_comet_logger_online(comet):
with patch("pytorch_lightning.loggers.comet.CometExperiment") as comet_experiment:
logger = CometLogger(api_key="key", workspace="dummy-test", project_name="general")

_ = logger.experiment
_ = logger.comet_experiment

comet_experiment.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general")

# Test both given
with patch("pytorch_lightning.loggers.comet.CometExperiment") as comet_experiment:
logger = CometLogger(save_dir="test", api_key="key", workspace="dummy-test", project_name="general")

_ = logger.experiment
_ = logger.comet_experiment

comet_experiment.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general")

Expand All @@ -59,7 +59,7 @@ def test_comet_logger_online(comet):
project_name="general",
)

_ = logger.experiment
_ = logger.comet_experiment

comet_existing.assert_called_once_with(
api_key="key", workspace="dummy-test", project_name="general", previous_experiment="test"
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_comet_logger_experiment_name(comet):
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)
assert logger._experiment is None

_ = logger.experiment
_ = logger.comet_experiment
comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)
comet_experiment().set_name.assert_called_once_with(experiment_name)

Expand All @@ -120,7 +120,7 @@ def save_os_environ(*args, **kwargs):
assert logger.version == experiment_key
assert logger._experiment is None

_ = logger.experiment
_ = logger.comet_experiment
comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)

assert instantation_environ["COMET_EXPERIMENT_KEY"] == experiment_key
Expand All @@ -142,13 +142,13 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch
assert logger.name == "test"
assert logger.version == "4321"

_ = logger.experiment
_ = logger.comet_experiment

comet_experiment.assert_called_once_with(offline_directory=tmpdir, project_name="test")

# mock return values of experiment
logger.experiment.id = "1"
logger.experiment.project_name = "test"
logger.comet_experiment.id = "1"
logger.comet_experiment.project_name = "test"

model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_train_batches=3, limit_val_batches=3)
Expand Down Expand Up @@ -204,7 +204,7 @@ def test_comet_version_without_experiment(comet):
assert logger.version == first_version
assert logger._experiment is None

_ = logger.experiment
_ = logger.comet_experiment

logger.reset_experiment()

Expand All @@ -220,7 +220,7 @@ def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch):
_patch_comet_atexit(monkeypatch)
logger = CometLogger(project_name="test", save_dir=tmpdir)
logger.log_metrics({"test": 1, "epoch": 1}, step=123)
logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)
logger.comet_experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)


@patch("pytorch_lightning.loggers.comet.CometExperiment")
Expand Down