
Commit b7f451a

benglewis authored and lexierule committed
Add checkpoint artifact path prefix to MLflow logger (#20538)
* Add a new `checkpoint_artifact_path_prefix` parameter to the MLflow logger.
* Modify `src/lightning/pytorch/loggers/mlflow.py` to include the new parameter in the `MLFlowLogger` class constructor and use it in the `after_save_checkpoint` method.
* Update the documentation in `docs/source-pytorch/visualize/loggers.rst` to include the new `checkpoint_artifact_path_prefix` parameter.
* Add a new test in `tests/tests_pytorch/loggers/test_mlflow.py` to verify the functionality of the `checkpoint_artifact_path_prefix` parameter and ensure it is used in the artifact path.
* Add CHANGELOG entry.
* Fix MLflow logger test for `checkpoint_path_prefix`.
* Update stale documentation.

Co-authored-by: Luca Antiga <[email protected]>
(cherry picked from commit 87108d8)
1 parent 4423f18 commit b7f451a

4 files changed, +70 −2 lines changed


docs/source-pytorch/visualize/loggers.rst

+34
@@ -54,3 +54,37 @@ Track and Visualize Experiments
 
 </div>
 </div>
+
+.. _mlflow_logger:
+
+MLflow Logger
+-------------
+
+The MLflow logger in PyTorch Lightning now includes a ``checkpoint_path_prefix`` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts.
+
+Example usage:
+
+.. code-block:: python
+
+    import lightning as L
+    from lightning.pytorch.loggers import MLFlowLogger
+
+    mlf_logger = MLFlowLogger(
+        experiment_name="lightning_logs",
+        tracking_uri="file:./ml-runs",
+        checkpoint_path_prefix="my_prefix"
+    )
+    trainer = L.Trainer(logger=mlf_logger)
+
+    # Your LightningModule definition
+    class LitModel(L.LightningModule):
+        def training_step(self, batch, batch_idx):
+            # example
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+        def any_lightning_module_function_or_hook(self):
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+    # Train your model
+    model = LitModel()
+    trainer.fit(model)
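
Note: as a quick sanity check of where the prefixed checkpoints end up, the run's artifacts can be listed with the MLflow client. This is a minimal sketch, not part of the diff; it assumes the `mlf_logger` from the example above has already logged at least one checkpoint to the local `./ml-runs` store:

    from mlflow.tracking import MlflowClient

    # List everything logged under the prefix for this run; the run_id
    # property is exposed by Lightning's MLFlowLogger once the run exists.
    client = MlflowClient(tracking_uri="file:./ml-runs")
    for artifact in client.list_artifacts(mlf_logger.run_id, "my_prefix"):
        print(artifact.path)  # e.g. "my_prefix/epoch=1-step=6"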

src/lightning/pytorch/CHANGELOG.md

+2
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which controls the path where the MLflow artifacts for the model checkpoints are stored.
+
 ### Removed
 
 ### Fixed

src/lightning/pytorch/loggers/mlflow.py

+4 −2
@@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self):
                 :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
                 which also logs every checkpoint during training.
             * if ``log_model == False`` (default), no checkpoint is logged.
-
+        checkpoint_path_prefix: A string to prefix the checkpoint artifact's path.
         prefix: A string to put at the beginning of metric keys.
         artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
             default.
@@ -121,6 +121,7 @@ def __init__(
         tags: Optional[dict[str, Any]] = None,
         save_dir: Optional[str] = "./mlruns",
         log_model: Literal[True, False, "all"] = False,
+        checkpoint_path_prefix: str = "",
         prefix: str = "",
         artifact_location: Optional[str] = None,
         run_id: Optional[str] = None,
@@ -147,6 +148,7 @@ def __init__(
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
         self._initialized = False
+        self._checkpoint_path_prefix = checkpoint_path_prefix
 
         from mlflow.tracking import MlflowClient
 
@@ -361,7 +363,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
             aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
 
             # Artifact path on mlflow
-            artifact_path = Path(p).stem
+            artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem
 
             # Log the checkpoint
             self.experiment.log_artifact(self._run_id, p, artifact_path)
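
Note: the new artifact path is composed with `pathlib`, which keeps the change backward compatible: with the default empty prefix, `Path("") / stem` normalizes back to the bare stem, so existing runs see no path change. A standalone sketch of that behavior (illustration only, not part of the diff):

    from pathlib import Path

    stem = Path("epoch=1-step=6.ckpt").stem  # checkpoint filename without extension
    print(Path("") / stem)           # epoch=1-step=6  (default prefix: unchanged)
    print(Path("my_prefix") / stem)  # my_prefix/epoch=1-step=6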

tests/tests_pytorch/loggers/test_mlflow.py

+30
@@ -427,3 +427,33 @@ def test_set_tracking_uri(mlflow_mock):
     mlflow_mock.set_tracking_uri.assert_not_called()
     _ = logger.experiment
     mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
+
+
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+def test_mlflow_log_model_with_checkpoint_path_prefix(mlflow_mock, tmp_path):
+    """Test that the logger creates the folders and files in the right place with a prefix."""
+    client = mlflow_mock.tracking.MlflowClient
+
+    # Get model, logger, trainer and train
+    model = BoringModel()
+    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model="all", checkpoint_path_prefix="my_prefix")
+    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        logger=logger,
+        max_epochs=2,
+        limit_train_batches=3,
+        limit_val_batches=3,
+    )
+    trainer.fit(model)
+
+    # Checkpoint log
+    assert client.return_value.log_artifact.call_count == 2
+    # Metadata and aliases log
+    assert client.return_value.log_artifacts.call_count == 2
+
+    # Check that the prefix is used in the artifact path
+    for call in client.return_value.log_artifact.call_args_list:
+        args, _ = call
+        assert str(args[2]).startswith("my_prefix")
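
Note: the final assertion relies on the positional layout of `log_artifact(run_id, local_path, artifact_path)`, so `args[2]` is the artifact path. A self-contained illustration of the `call_args_list` inspection pattern used above (hypothetical, not part of the test suite):

    from unittest import mock

    m = mock.Mock()
    # Simulate the call the logger makes for each saved checkpoint.
    m.log_artifact("run-id", "/tmp/epoch=0-step=3.ckpt", "my_prefix/epoch=0-step=3")
    args, kwargs = m.log_artifact.call_args
    assert str(args[2]).startswith("my_prefix")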
