diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 277af5c85f539..a708387969888 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -19,23 +19,25 @@ import logging import os from argparse import Namespace -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union +from typing import Any, Dict, Literal, Mapping, Optional, TYPE_CHECKING, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor from torch.nn import Module from typing_extensions import override -from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict +from lightning.fabric.utilities.logger import _convert_params from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment -from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only if TYPE_CHECKING: - from comet_ml import ExistingExperiment, Experiment, OfflineExperiment + from comet_ml import ExistingExperiment, Experiment, OfflineExperiment, ExperimentConfig, BaseExperiment log = logging.getLogger(__name__) -_COMET_AVAILABLE = RequirementCache("comet-ml>=3.31.0", module="comet_ml") +_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml") + +comet_experiment = Union["Experiment", "ExistingExperiment", "OfflineExperiment"] +framework = "pytorch-lightning" class CometLogger(Logger): @@ -64,7 +66,6 @@ class CometLogger(Logger): workspace=os.environ.get("COMET_WORKSPACE"), # Optional save_dir=".", # Optional project_name="default_project", # Optional - rest_api_key=os.environ.get("COMET_REST_API_KEY"), # Optional experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"), # Optional experiment_name="lightning_logs", # Optional ) @@ -81,7 +82,6 @@ class CometLogger(Logger): save_dir=".", workspace=os.environ.get("COMET_WORKSPACE"), # Optional project_name="default_project", # Optional - rest_api_key=os.environ.get("COMET_REST_API_KEY"), # Optional experiment_name="lightning_logs", # Optional ) trainer = Trainer(logger=comet_logger) @@ -106,6 +106,9 @@ def __init__(self, *args, **kwarg): # log multiple parameters logger.log_hyperparams({"batch_size": 16, "learning_rate": 0.001}) + # log nested parameters + logger.log_hyperparams({"specific": {'param': {'subparam': "value"}}}) + **Log Metrics:** .. code-block:: python @@ -116,6 +119,9 @@ def __init__(self, *args, **kwarg): # add multiple metrics logger.log_metrics({"train/loss": 0.001, "val/loss": 0.002}) + # add nested metrics + logger.log_hyperparams({"specific": {'metric': {'submetric': "value"}}}) + **Access the Comet Experiment object:** You can gain access to the underlying Comet @@ -166,100 +172,73 @@ def __init__(self, *args, **kwarg): - `Comet Documentation `__ Args: - api_key: Required in online mode. API key, found on Comet.ml. If not given, this - will be loaded from the environment variable COMET_API_KEY or ~/.comet.config - if either exists. - save_dir: Required in offline mode. The path for the directory to save local - comet logs. If given, this also sets the directory for saving checkpoints. - project_name: Optional. Send your experiment to a specific project. - Otherwise will be sent to Uncategorized Experiments. - If the project name does not already exist, Comet.ml will create a new project. - rest_api_key: Optional. Rest API key found in Comet.ml settings. - This is used to determine version number - experiment_name: Optional. String representing the name for this particular experiment on Comet.ml. - experiment_key: Optional. If set, restores from existing experiment. - offline: If api_key and save_dir are both given, this determines whether - the experiment will be in online or offline mode. This is useful if you use - save_dir to control the checkpoints directory and have a ~/.comet.config - file but still want to run offline experiments. - prefix: A string to put at the beginning of metric keys. - \**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by + api_key (str, optional): Comet API key. It's recommended to configure the API Key with `comet login`. + workspace (str, optional): Comet workspace name. If not provided, uses the default workspace. + project (str, optional): Comet project name. Defaults to `Uncategorized`. + experiment_key (str, optional): The Experiment identifier to be used for logging. This is used either to append + data to an Existing Experiment or to control the key of new experiments (for example to match another + identifier). Must be an alphanumeric string whose length is between 32 and 50 characters. + mode (str, optional): Control how the Comet experiment is started. + * ``"get_or_create"``: Starts a fresh experiment if required, or persists logging to an existing one. + * ``"get"``: Continue logging to an existing experiment identified by the ``experiment_key`` value. + * ``"create"``: Always creates of a new experiment, useful for HPO sweeps. + online (boolean, optional): If True, the data will be logged to Comet server, otherwise it will be stored + locally in an offline experiment. Default is ``True``. + **kwargs: Additional arguments like `experiment_name`, `log_code`, `prefix`, `offline_directory` etc. used by :class:`CometExperiment` can be passed as keyword arguments in this logger. Raises: ModuleNotFoundError: If required Comet package is not installed on the device. - MisconfigurationException: - If neither ``api_key`` nor ``save_dir`` are passed as arguments. + ValueError: If no API Key is set in online mode. + ExperimentNotFound: If mode="get" and the experiment_key doesn't exist, or you don't have access to it. + InvalidExperimentMode: + * If mode="get" but no experiment_key was passed or configured. + * If mode="create", an experiment_key was passed or configured and + an Experiment with that Key already exists. """ - LOGGER_JOIN_CHAR = "-" - def __init__( self, api_key: Optional[str] = None, - save_dir: Optional[str] = None, - project_name: Optional[str] = None, - rest_api_key: Optional[str] = None, - experiment_name: Optional[str] = None, + workspace: Optional[str] = None, + project: Optional[str] = None, experiment_key: Optional[str] = None, - offline: bool = False, - prefix: str = "", + mode: Optional[Literal["get_or_create", "get", "create"]] = None, + online: Optional[bool] = None, **kwargs: Any, ): if not _COMET_AVAILABLE: raise ModuleNotFoundError(str(_COMET_AVAILABLE)) + super().__init__() - self._experiment = None - self._save_dir: Optional[str] - self.rest_api_key: Optional[str] # needs to be set before the first `comet_ml` import + # because comet_ml imported after another machine learning libraries (Torch) os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1" + self._prefix = kwargs.pop("prefix", None) + import comet_ml - # Determine online or offline mode based on which arguments were passed to CometLogger - api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config()) - - if api_key is not None and save_dir is not None: - self.mode = "offline" if offline else "online" - self.api_key = api_key - self._save_dir = save_dir - elif api_key is not None: - self.mode = "online" - self.api_key = api_key - self._save_dir = None - elif save_dir is not None: - self.mode = "offline" - self._save_dir = save_dir - else: - # If neither api_key nor save_dir are passed as arguments, raise an exception - raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.") - - log.info(f"CometLogger will be initialized in {self.mode} mode") - - self._project_name: Optional[str] = project_name - self._experiment_key: Optional[str] = experiment_key - self._experiment_name: Optional[str] = experiment_name - self._prefix: str = prefix - self._kwargs: Any = kwargs - self._future_experiment_key: Optional[str] = None - - if rest_api_key is not None: - from comet_ml.api import API - - # Comet.ml rest API, used to determine version number - self.rest_api_key = rest_api_key - self.comet_api = API(self.rest_api_key) - else: - self.rest_api_key = None - self.comet_api = None + comet_config = comet_ml.ExperimentConfig(**kwargs) + + self._experiment = comet_ml.start( + api_key=api_key, + workspace=workspace, + project=project, + experiment_key=experiment_key, + mode=mode, + online=online, + experiment_config=comet_config, + ) + + self._experiment.log_other("Created from", "pytorch-lightning") @property @rank_zero_experiment - def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperiment"]: + def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperiment", "BaseExperiment"]: r"""Actual Comet object. To use Comet features in your :class:`~lightning.pytorch.core.LightningModule` do the following. @@ -268,82 +247,48 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi self.logger.experiment.some_comet_function() """ - if self._experiment is not None and self._experiment.alive: - return self._experiment - - if self._future_experiment_key is not None: - os.environ["COMET_EXPERIMENT_KEY"] = self._future_experiment_key - - from comet_ml import ExistingExperiment, Experiment, OfflineExperiment - - try: - if self.mode == "online": - if self._experiment_key is None: - self._experiment = Experiment(api_key=self.api_key, project_name=self._project_name, **self._kwargs) - self._experiment_key = self._experiment.get_key() - else: - self._experiment = ExistingExperiment( - api_key=self.api_key, - project_name=self._project_name, - previous_experiment=self._experiment_key, - **self._kwargs, - ) - else: - self._experiment = OfflineExperiment( - offline_directory=self.save_dir, project_name=self._project_name, **self._kwargs - ) - self._experiment.log_other("Created from", "pytorch-lightning") - finally: - if self._future_experiment_key is not None: - os.environ.pop("COMET_EXPERIMENT_KEY") - self._future_experiment_key = None - - if self._experiment_name: - self._experiment.set_name(self._experiment_name) - return self._experiment @override @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) - params = _flatten_dict(params) - self.experiment.log_parameters(params) + self.experiment.__internal_api__log_parameters__( + parameters=params, + framework=framework, + ) @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - # Comet.ml expects metrics to be a dictionary of detached tensors on CPU + # Comet.com expects metrics to be a dictionary of detached tensors on CPU metrics_without_epoch = metrics.copy() for key, val in metrics_without_epoch.items(): if isinstance(val, Tensor): metrics_without_epoch[key] = val.cpu().detach() epoch = metrics_without_epoch.pop("epoch", None) - metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR) - self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) - - def reset_experiment(self) -> None: - self._experiment = None + self.experiment.__internal_api__log_metrics__( + metrics_without_epoch, + step=step, + epoch=epoch, + prefix=self._prefix, + framework=framework, + ) @override @rank_zero_only def finalize(self, status: str) -> None: - r"""When calling ``self.experiment.end()``, that experiment won't log any more data to Comet. That's why, if you - need to log any more data, you need to create an ExistingCometExperiment. For example, to log data when testing - your model after training, because when training is finalized :meth:`CometLogger.finalize` is called. - - This happens automatically in the :meth:`~CometLogger.experiment` property, when - ``self._experiment`` is set to ``None``, i.e. ``self.reset_experiment()``. - - """ + """We will not end experiment (self._experiment.end()) here to have an ability to continue using it after + training is complete but instead of ending we will upload/save all the data.""" if self._experiment is None: # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been # initialized there return - self.experiment.end() - self.reset_experiment() + + # just save the data + self.experiment.flush() @property @override @@ -354,7 +299,11 @@ def save_dir(self) -> Optional[str]: The path to the save directory. """ - return self._save_dir + import comet_ml + if isinstance(self._experiment, comet_ml.OfflineExperiment): + return self._experiment.offline_directory + + return None @property @override @@ -362,17 +311,10 @@ def name(self) -> str: """Gets the project name. Returns: - The project name if it is specified, else "comet-default". + The project name. """ - # Don't create an experiment if we don't have one - if self._experiment is not None and self._experiment.project_name is not None: - return self._experiment.project_name - - if self._project_name is not None: - return self._project_name - - return "comet-default" + return self._experiment.project_name @property @override @@ -380,35 +322,10 @@ def version(self) -> str: """Gets the version. Returns: - The first one of the following that is set in the following order - - 1. experiment id. - 2. experiment key. - 3. "COMET_EXPERIMENT_KEY" environment variable. - 4. future experiment key. - - If none are present generates a new guid. + experiment key """ - # Don't create an experiment if we don't have one - if self._experiment is not None: - return self._experiment.id - - if self._experiment_key is not None: - return self._experiment_key - - if "COMET_EXPERIMENT_KEY" in os.environ: - return os.environ["COMET_EXPERIMENT_KEY"] - - if self._future_experiment_key is not None: - return self._future_experiment_key - - import comet_ml - - # Pre-generate an experiment key - self._future_experiment_key = comet_ml.generate_guid() - - return self._future_experiment_key + return self._experiment.get_key() def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() @@ -416,7 +333,7 @@ def __getstate__(self) -> Dict[str, Any]: # Save the experiment id in case an experiment object already exists, # this way we could create an ExistingExperiment pointing to the same # experiment - state["_experiment_key"] = self._experiment.id if self._experiment is not None else None + state["_experiment_key"] = self._experiment.get_key() if self._experiment is not None else None # Remove the experiment object as it contains hard to pickle objects # (like network connections), the experiment object will be recreated if @@ -426,5 +343,7 @@ def __getstate__(self) -> Dict[str, Any]: @override def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: - if self._experiment is not None: - self._experiment.set_model_graph(model) + self._experiment.__internal_api__set_model_graph__( + graph=model, + framework=framework, + ) diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index e467c63543ede..224509faabdf1 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -60,11 +60,6 @@ def test_comet_logger_online(comet_mock): ) comet_existing().set_name.assert_called_once_with("experiment") - # API experiment - api = comet_mock.api.API - CometLogger(api_key="key", workspace="dummy-test", project_name="general", rest_api_key="rest") - api.assert_called_once_with("rest") - @mock.patch.dict(os.environ, {}) def test_comet_experiment_resets_if_not_alive(comet_mock):