Skip to content

Lazily import dependencies for MLFlowLogger #18528

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Sep 12, 2023
Merged
7 changes: 2 additions & 5 deletions src/lightning/pytorch/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,14 @@
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE, CometLogger # noqa: F401
from lightning.pytorch.loggers.csv_logs import CSVLogger
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger # noqa: F401
from lightning.pytorch.loggers.mlflow import MLFlowLogger
from lightning.pytorch.loggers.neptune import NeptuneLogger
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers.wandb import WandbLogger

__all__ = ["CSVLogger", "Logger", "TensorBoardLogger", "WandbLogger", "NeptuneLogger"]
__all__ = ["CSVLogger", "Logger", "MLFlowLogger", "TensorBoardLogger", "WandbLogger", "NeptuneLogger"]

if _COMET_AVAILABLE:
__all__.append("CometLogger")
# needed to prevent ModuleNotFoundError and duplicated logs.
os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"

if _MLFLOW_AVAILABLE:
__all__.append("MLFlowLogger")
65 changes: 30 additions & 35 deletions src/lightning/pytorch/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from argparse import Namespace
from pathlib import Path
from time import time
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Union

import yaml
from lightning_utilities.core.imports import RequirementCache
Expand All @@ -37,39 +37,6 @@
log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")
if _MLFLOW_AVAILABLE:
import mlflow
from mlflow.entities import Metric, Param
from mlflow.tracking import context, MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
else:
mlflow = None
MlflowClient, context = None, None
Metric, Param = None, None
MLFLOW_RUN_NAME = "mlflow.runName"

# before v1.1.0
if hasattr(context, "resolve_tags"):
from mlflow.tracking.context import resolve_tags


# since v1.1.0
elif hasattr(context, "registry"):
from mlflow.tracking.context.registry import resolve_tags
else:

def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]:
"""
Args:
tags: A dictionary of tags to override. If specified, tags passed in this argument will
override those inferred from the context.

Returns: A dictionary of resolved tags.

Note:
See ``mlflow.tracking.context.registry`` for more details.
"""
return tags


class MLFlowLogger(Logger):
Expand Down Expand Up @@ -169,11 +136,13 @@ def __init__(

self._initialized = False

from mlflow.tracking import MlflowClient

self._mlflow_client = MlflowClient(tracking_uri)

@property
@rank_zero_experiment
def experiment(self) -> MlflowClient:
def experiment(self) -> Any:
r"""
Actual MLflow object. To use MLflow features in your
:class:`~lightning.pytorch.core.module.LightningModule` do the following.
Expand All @@ -183,6 +152,8 @@ def experiment(self) -> MlflowClient:
self.logger.experiment.some_mlflow_function()

"""
import mlflow

if self._initialized:
return self._mlflow_client

Expand All @@ -207,11 +178,16 @@ def experiment(self) -> MlflowClient:
if self._run_id is None:
if self._run_name is not None:
self.tags = self.tags or {}

from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME

if MLFLOW_RUN_NAME in self.tags:
log.warning(
f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}."
)
self.tags[MLFLOW_RUN_NAME] = self._run_name

resolve_tags = _get_resolve_tags()
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
self._run_id = run.info.run_id
self._initialized = True
Expand Down Expand Up @@ -244,6 +220,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: #
params = _convert_params(params)
params = _flatten_dict(params)

from mlflow.entities import Param

# Truncate parameter values to 250 characters.
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
Expand All @@ -256,6 +234,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: #
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

from mlflow.entities import Metric

metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
metrics_list: List[Metric] = []

Expand Down Expand Up @@ -383,3 +363,18 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non

# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
self._logged_model_time[p] = t


def _get_resolve_tags() -> Callable:
    """Locate and return mlflow's ``resolve_tags`` function lazily.

    The helper's import path moved between mlflow releases, so we probe the
    ``mlflow.tracking.context`` module at call time instead of importing at
    module load (keeps ``mlflow`` an optional dependency).

    Returns:
        The ``resolve_tags`` callable for the installed mlflow version, or an
        identity function when no resolver is available.
    """
    from mlflow.tracking import context

    # before v1.1.0
    if hasattr(context, "resolve_tags"):
        from mlflow.tracking.context import resolve_tags
    # since v1.1.0
    elif hasattr(context, "registry"):
        from mlflow.tracking.context.registry import resolve_tags
    else:
        # Fallback: no resolver exposed by this mlflow version — return the
        # tags unchanged. (A named def instead of a lambda assignment, per PEP 8.)
        def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]:
            return tags

    return resolve_tags
27 changes: 27 additions & 0 deletions tests/tests_pytorch/loggers/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import sys
from types import ModuleType
from unittest.mock import Mock

import pytest


@pytest.fixture()
def mlflow_mock(monkeypatch):
    """Install a fake ``mlflow`` package tree into ``sys.modules``.

    Provides stub ``mlflow``, ``mlflow.tracking`` and ``mlflow.entities``
    modules so tests can exercise ``MLFlowLogger`` without mlflow installed.
    Returns the root fake module.
    """
    root = ModuleType("mlflow")
    root.set_tracking_uri = Mock()

    tracking = ModuleType("tracking")
    tracking.MlflowClient = Mock()
    tracking.artifact_utils = Mock()

    entities = ModuleType("entities")
    entities.Metric = Mock()
    entities.Param = Mock()
    entities.time = Mock()

    # Register the fake modules; monkeypatch restores sys.modules afterwards.
    for name, module in (
        ("mlflow", root),
        ("mlflow.tracking", tracking),
        ("mlflow.entities", entities),
    ):
        monkeypatch.setitem(sys.modules, name, module)

    # Expose submodules as attributes, mirroring a real package layout.
    root.tracking = tracking
    root.entities = entities
    return root
22 changes: 10 additions & 12 deletions tests/tests_pytorch/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
mock.patch("lightning.pytorch.loggers.comet.comet_ml"),
mock.patch("lightning.pytorch.loggers.comet.CometOfflineExperiment"),
mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True),
mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient"),
mock.patch("lightning.pytorch.loggers.mlflow.Metric"),
mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()),
mock.patch("lightning.pytorch.loggers.neptune.neptune", new_callable=create_neptune_mock),
mock.patch("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
mock.patch("lightning.pytorch.loggers.neptune.Run", new=mock.Mock),
Expand Down Expand Up @@ -82,10 +81,9 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
return logger_class(**args)


@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True)
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES)
def test_loggers_fit_test_all(tmpdir, monkeypatch, logger_class):
def test_loggers_fit_test_all(logger_class, mlflow_mock, tmpdir):
"""Verify that basic functionality of all loggers."""
with contextlib.ExitStack() as stack:
for mgr in LOGGER_CTX_MANAGERS:
Expand Down Expand Up @@ -300,7 +298,7 @@ def _test_logger_initialization(tmpdir, logger_class):
trainer.fit(model)


def test_logger_with_prefix_all(tmpdir, monkeypatch):
def test_logger_with_prefix_all(mlflow_mock, monkeypatch, tmpdir):
"""Test that prefix is added at the beginning of the metric keys."""
prefix = "tmp"

Expand All @@ -315,10 +313,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):

# MLflow
with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
"lightning.pytorch.loggers.mlflow.Metric"
) as Metric, mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient"), mock.patch(
"lightning.pytorch.loggers.mlflow.mlflow"
"lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()
):
Metric = mlflow_mock.entities.Metric
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix)
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log_batch.assert_called_once_with(
Expand Down Expand Up @@ -358,7 +355,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})


def test_logger_default_name(tmpdir, monkeypatch):
def test_logger_default_name(mlflow_mock, monkeypatch, tmpdir):
"""Test that the default logger name is lightning_logs."""
# CSV
logger = CSVLogger(save_dir=tmpdir)
Expand All @@ -376,9 +373,10 @@ def test_logger_default_name(tmpdir, monkeypatch):

# MLflow
with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
"lightning.pytorch.loggers.mlflow.MlflowClient"
) as mlflow_client, mock.patch("lightning.pytorch.loggers.mlflow.mlflow"):
mlflow_client().get_experiment_by_name.return_value = None
"lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()
):
client = mlflow_mock.tracking.MlflowClient()
client.get_experiment_by_name.return_value = None
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir)

_ = logger.experiment
Expand Down
Loading