diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index b544212e755e2..ac49c846dccbe 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -29,6 +29,7 @@ from lightning.fabric.utilities.logger import _convert_params from lightning.fabric.utilities.rank_zero import _get_rank +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment from lightning.pytorch.utilities.rank_zero import rank_zero_only @@ -187,6 +188,8 @@ def __init__(self, *args, **kwarg): locally in an offline experiment. Default is ``True``. prefix: The prefix to add to names of the logged metrics. example: prefix=`exp1`, then metric name will be logged as `exp1_metric_name` + flush_every: Controls whether the Comet experiment flushes logs to the Comet server after each checkpoint. + If no value is provided, flushing will not occur. **kwargs: Additional arguments like `name`, `log_code`, `offline_directory` etc. used by :class:`CometExperiment` can be passed as keyword arguments in this logger. 
@@ -206,6 +209,7 @@ def __init__( mode: Optional[Literal["get_or_create", "get", "create"]] = None, online: Optional[bool] = None, prefix: Optional[str] = None, + flush_every: Optional[Literal["checkpoint"]] = None, **kwargs: Any, ): if not _COMET_AVAILABLE: @@ -263,6 +267,7 @@ def __init__( self._experiment_key: Optional[str] = experiment_key self._prefix: Optional[str] = prefix self._kwargs: dict[str, Any] = kwargs + self._flush_every: Optional[Literal["checkpoint"]] = flush_every # needs to be set before the first `comet_ml` import # because comet_ml imported after another machine learning libraries (Torch) @@ -419,3 +424,8 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None graph=model, framework=FRAMEWORK_NAME, ) + + @override + def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: + if self._experiment is not None and self._flush_every == "checkpoint": + self._experiment.flush() diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index dae8f617b873e..36028f3f2771b 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -199,3 +199,27 @@ def test_comet_metrics_safe(comet_mock, tmp_path, monkeypatch): metrics = {"tensor": tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), "epoch": 1} logger.log_metrics(metrics) assert metrics["tensor"].requires_grad + + +@mock.patch.dict(os.environ, {}) +def test_comet_flush_every_checkpoint(comet_mock): + """Test that the CometLogger flushes the Comet experiment after each checkpoint.""" + + logger = CometLogger(flush_every="checkpoint") + assert logger._experiment is not None + + logger.after_save_checkpoint(Mock()) + + logger._experiment.flush.assert_called_once() + + +@mock.patch.dict(os.environ, {}) +def test_comet_flush_every_not_called(comet_mock): + """Test that the CometLogger does not call Comet experiment flush if not requested after each checkpoint.""" + + logger = 
CometLogger() + assert logger._experiment is not None + + logger.after_save_checkpoint(Mock()) + + logger._experiment.flush.assert_not_called()