Skip to content

Commit 2b9f0ae

Browse files
authored
Lazily import dependencies for NeptuneLogger (#18573)
1 parent b727f3f commit 2b9f0ae

File tree

7 files changed

+197
-255
lines changed

7 files changed

+197
-255
lines changed

src/lightning/pytorch/loggers/neptune.py

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,11 @@
1515
Neptune Logger
1616
--------------
1717
"""
18-
__all__ = [
19-
"NeptuneLogger",
20-
]
21-
2218
import contextlib
2319
import logging
2420
import os
2521
from argparse import Namespace
26-
from typing import Any, Dict, Generator, List, Optional, Set, Union
22+
from typing import Any, Dict, Generator, List, Optional, Set, TYPE_CHECKING, Union
2723

2824
from lightning_utilities.core.imports import RequirementCache
2925
from torch import Tensor
@@ -35,30 +31,15 @@
3531
from lightning.pytorch.utilities.model_summary import ModelSummary
3632
from lightning.pytorch.utilities.rank_zero import rank_zero_only
3733

38-
# neptune is available with two names on PyPI : `neptune` and `neptune-client`
39-
_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0")
40-
_NEPTUNE_CLIENT_AVAILABLE = RequirementCache("neptune-client")
41-
42-
if _NEPTUNE_AVAILABLE:
43-
# >1.0 package structure
44-
import neptune
34+
if TYPE_CHECKING:
4535
from neptune import Run
4636
from neptune.handler import Handler
47-
from neptune.types import File
48-
from neptune.utils import stringify_unsupported
49-
elif _NEPTUNE_CLIENT_AVAILABLE:
50-
# <1.0 package structure
51-
import neptune.new as neptune
52-
from neptune.new import Run
53-
from neptune.new.handler import Handler
54-
from neptune.new.types import File
55-
from neptune.new.utils import stringify_unsupported
56-
else:
57-
# needed for tests, mocks and function signatures
58-
neptune, Run, Handler, File, stringify_unsupported = None, None, None, None, None
5937

6038
log = logging.getLogger(__name__)
6139

40+
# neptune is available with two names on PyPI : `neptune` and `neptune-client`
41+
_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0")
42+
_NEPTUNE_CLIENT_AVAILABLE = RequirementCache("neptune-client")
6243
_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"
6344

6445

@@ -258,6 +239,11 @@ def __init__(
258239
if self._run_instance is not None:
259240
self._retrieve_run_data()
260241

242+
if _NEPTUNE_AVAILABLE:
243+
from neptune.handler import Handler
244+
else:
245+
from neptune.new.handler import Handler
246+
261247
# make sure that we've log integration version for outside `Run` instances
262248
root_obj = self._run_instance
263249
if isinstance(root_obj, Handler):
@@ -266,6 +252,11 @@ def __init__(
266252
root_obj[_INTEGRATION_VERSION_KEY] = pl.__version__
267253

268254
def _retrieve_run_data(self) -> None:
255+
if _NEPTUNE_AVAILABLE:
256+
from neptune.handler import Handler
257+
else:
258+
from neptune.new.handler import Handler
259+
269260
assert self._run_instance is not None
270261
root_obj = self._run_instance
271262
if isinstance(root_obj, Handler):
@@ -317,6 +308,12 @@ def _verify_input_arguments(
317308
run: Optional[Union["Run", "Handler"]],
318309
neptune_run_kwargs: dict,
319310
) -> None:
311+
if _NEPTUNE_AVAILABLE:
312+
from neptune import Run
313+
from neptune.handler import Handler
314+
else:
315+
from neptune.new import Run
316+
from neptune.new.handler import Handler
320317
# check if user passed the client `Run`/`Handler` object
321318
if run is not None and not isinstance(run, (Run, Handler)):
322319
raise ValueError("Run parameter expected to be of type `neptune.Run`, or `neptune.handler.Handler`.")
@@ -325,8 +322,8 @@ def _verify_input_arguments(
325322
any_neptune_init_arg_passed = any(arg is not None for arg in [api_key, project, name]) or neptune_run_kwargs
326323
if run is not None and any_neptune_init_arg_passed:
327324
raise ValueError(
328-
"When an already initialized run object is provided"
329-
" you can't provide other neptune.init_run() parameters.\n"
325+
"When an already initialized run object is provided, you can't provide other `neptune.init_run()`"
326+
" parameters."
330327
)
331328

332329
def __getstate__(self) -> Dict[str, Any]:
@@ -336,12 +333,17 @@ def __getstate__(self) -> Dict[str, Any]:
336333
return state
337334

338335
def __setstate__(self, state: Dict[str, Any]) -> None:
336+
if _NEPTUNE_AVAILABLE:
337+
import neptune
338+
else:
339+
import neptune.new as neptune
340+
339341
self.__dict__ = state
340342
self._run_instance = neptune.init_run(**self._neptune_init_args)
341343

342344
@property
343345
@rank_zero_experiment
344-
def experiment(self) -> Run:
346+
def experiment(self) -> "Run":
345347
r"""
346348
Actual Neptune run object. Allows you to use neptune logging features in your
347349
:class:`~lightning.pytorch.core.module.LightningModule`.
@@ -371,7 +373,12 @@ def training_step(self, batch, batch_idx):
371373

372374
@property
373375
@rank_zero_experiment
374-
def run(self) -> Run:
376+
def run(self) -> "Run":
377+
if _NEPTUNE_AVAILABLE:
378+
import neptune
379+
else:
380+
import neptune.new as neptune
381+
375382
if not self._run_instance:
376383
self._run_instance = neptune.init_run(**self._neptune_init_args)
377384
self._retrieve_run_data()
@@ -416,6 +423,11 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: #
416423
neptune_logger.log_hyperparams(PARAMS)
417424
418425
"""
426+
if _NEPTUNE_AVAILABLE:
427+
from neptune.utils import stringify_unsupported
428+
else:
429+
from neptune.new.utils import stringify_unsupported
430+
419431
params = _convert_params(params)
420432
params = _sanitize_callable_params(params)
421433

@@ -468,8 +480,13 @@ def save_dir(self) -> Optional[str]:
468480

469481
@rank_zero_only
470482
def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None:
483+
if _NEPTUNE_AVAILABLE:
484+
from neptune.types import File
485+
else:
486+
from neptune.new.types import File
487+
471488
model_str = str(ModelSummary(model=model, max_depth=max_depth))
472-
self.run[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
489+
self.run[self._construct_path_with_prefix("model/summary")] = File.from_content(
473490
content=model_str, extension="txt"
474491
)
475492

@@ -484,6 +501,11 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
484501
if not self._log_model_checkpoints:
485502
return
486503

504+
if _NEPTUNE_AVAILABLE:
505+
from neptune.types import File
506+
else:
507+
from neptune.new.types import File
508+
487509
file_names = set()
488510
checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
489511

tests/tests_pytorch/loggers/conftest.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import sys
1515
from types import ModuleType
16-
from unittest.mock import Mock
16+
from unittest.mock import MagicMock, Mock
1717

1818
import pytest
1919

@@ -37,6 +37,8 @@ def mlflow_mock(monkeypatch):
3737

3838
mlflow.tracking = mlflow_tracking
3939
mlflow.entities = mlflow_entities
40+
41+
monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True),
4042
return mlflow
4143

4244

@@ -73,6 +75,8 @@ class RunType: # to make isinstance checks pass
7375
wandb.sdk = wandb_sdk
7476
wandb.sdk.lib = wandb_sdk_lib
7577
wandb.wandb_run = wandb_wandb_run
78+
79+
monkeypatch.setattr("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True)
7680
return wandb
7781

7882

@@ -92,4 +96,46 @@ def comet_mock(monkeypatch):
9296
monkeypatch.setitem(sys.modules, "comet_ml.api", comet_api)
9397

9498
comet.api = comet_api
99+
100+
monkeypatch.setattr("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True)
95101
return comet
102+
103+
104+
@pytest.fixture()
105+
def neptune_mock(monkeypatch):
106+
class RunType: # to make isinstance checks pass
107+
def get_root_object(self):
108+
pass
109+
110+
def __getitem__(self, item):
111+
pass
112+
113+
def __setitem__(self, key, value):
114+
pass
115+
116+
run_mock = MagicMock(spec=RunType, exists=Mock(return_value=False), wait=Mock(), get_structure=MagicMock())
117+
run_mock.get_root_object.return_value = run_mock
118+
119+
neptune = ModuleType("neptune")
120+
neptune.init_run = Mock(return_value=run_mock)
121+
neptune.Run = RunType
122+
monkeypatch.setitem(sys.modules, "neptune", neptune)
123+
124+
neptune_handler = ModuleType("handler")
125+
neptune_handler.Handler = RunType
126+
monkeypatch.setitem(sys.modules, "neptune.handler", neptune_handler)
127+
128+
neptune_types = ModuleType("types")
129+
neptune_types.File = Mock()
130+
monkeypatch.setitem(sys.modules, "neptune.types", neptune_types)
131+
132+
neptune_utils = ModuleType("utils")
133+
neptune_utils.stringify_unsupported = Mock()
134+
monkeypatch.setitem(sys.modules, "neptune.utils", neptune_utils)
135+
136+
neptune.handler = neptune_handler
137+
neptune.types = neptune_types
138+
neptune.utils = neptune_utils
139+
140+
monkeypatch.setattr("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", True)
141+
return neptune

tests/tests_pytorch/loggers/test_all.py

Lines changed: 35 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import contextlib
1514
import inspect
1615
import os
1716
import pickle
@@ -37,17 +36,7 @@
3736
from tests_pytorch.helpers.runif import RunIf
3837
from tests_pytorch.loggers.test_comet import _patch_comet_atexit
3938
from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation
40-
from tests_pytorch.loggers.test_neptune import create_neptune_mock
41-
42-
LOGGER_CTX_MANAGERS = (
43-
mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True),
44-
mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()),
45-
mock.patch("lightning.pytorch.loggers.neptune.neptune", new_callable=create_neptune_mock),
46-
mock.patch("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
47-
mock.patch("lightning.pytorch.loggers.neptune.Run", new=mock.Mock),
48-
mock.patch("lightning.pytorch.loggers.neptune.Handler", new=mock.Mock),
49-
mock.patch("lightning.pytorch.loggers.neptune.File", new=mock.Mock()),
50-
)
39+
5140
ALL_LOGGER_CLASSES = (
5241
CometLogger,
5342
CSVLogger,
@@ -79,18 +68,11 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
7968

8069

8170
@mock.patch.dict(os.environ, {})
82-
@mock.patch("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True)
83-
@mock.patch("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True)
71+
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
8472
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES)
85-
def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock, tmp_path):
73+
def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock, neptune_mock, tmp_path):
8674
"""Verify that basic functionality of all loggers."""
87-
with contextlib.ExitStack() as stack:
88-
for mgr in LOGGER_CTX_MANAGERS:
89-
stack.enter_context(mgr)
90-
_test_loggers_fit_test(tmp_path, logger_class)
91-
9275

93-
def _test_loggers_fit_test(tmp_path, logger_class):
9476
class CustomModel(BoringModel):
9577
def training_step(self, batch, batch_idx):
9678
loss = self.step(batch)
@@ -300,38 +282,32 @@ def _test_logger_initialization(tmp_path, logger_class):
300282

301283

302284
@mock.patch.dict(os.environ, {})
303-
def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, monkeypatch, tmp_path):
285+
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
286+
def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_mock, monkeypatch, tmp_path):
304287
"""Test that prefix is added at the beginning of the metric keys."""
305288
prefix = "tmp"
306289

307290
# Comet
308-
with mock.patch("lightning.pytorch.loggers.comet._COMET_AVAILABLE", return_value=True):
309-
_patch_comet_atexit(monkeypatch)
310-
logger = _instantiate_logger(CometLogger, save_dir=tmp_path, prefix=prefix)
311-
logger.log_metrics({"test": 1.0}, step=0)
312-
logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0)
291+
_patch_comet_atexit(monkeypatch)
292+
logger = _instantiate_logger(CometLogger, save_dir=tmp_path, prefix=prefix)
293+
logger.log_metrics({"test": 1.0}, step=0)
294+
logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0)
313295

314296
# MLflow
315-
with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
316-
"lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()
317-
):
318-
Metric = mlflow_mock.entities.Metric
319-
logger = _instantiate_logger(MLFlowLogger, save_dir=tmp_path, prefix=prefix)
320-
logger.log_metrics({"test": 1.0}, step=0)
321-
logger.experiment.log_batch.assert_called_once_with(
322-
run_id=ANY, metrics=[Metric(key="tmp-test", value=1.0, timestamp=ANY, step=0)]
323-
)
297+
Metric = mlflow_mock.entities.Metric
298+
logger = _instantiate_logger(MLFlowLogger, save_dir=tmp_path, prefix=prefix)
299+
logger.log_metrics({"test": 1.0}, step=0)
300+
logger.experiment.log_batch.assert_called_once_with(
301+
run_id=ANY, metrics=[Metric(key="tmp-test", value=1.0, timestamp=ANY, step=0)]
302+
)
324303

325304
# Neptune
326-
with mock.patch("lightning.pytorch.loggers.neptune.neptune"), mock.patch(
327-
"lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True
328-
), mock.patch("lightning.pytorch.loggers.neptune.Handler", new=mock.Mock):
329-
logger = _instantiate_logger(NeptuneLogger, api_key="test", project="project", save_dir=tmp_path, prefix=prefix)
330-
assert logger.experiment.__getitem__.call_count == 0
331-
logger.log_metrics({"test": 1.0}, step=0)
332-
assert logger.experiment.__getitem__.call_count == 1
333-
logger.experiment.__getitem__.assert_called_with("tmp/test")
334-
logger.experiment.__getitem__().append.assert_called_once_with(1.0)
305+
logger = _instantiate_logger(NeptuneLogger, api_key="test", project="project", save_dir=tmp_path, prefix=prefix)
306+
assert logger.experiment.__getitem__.call_count == 0
307+
logger.log_metrics({"test": 1.0}, step=0)
308+
assert logger.experiment.__getitem__.call_count == 1
309+
logger.experiment.__getitem__.assert_called_with("tmp/test")
310+
logger.experiment.__getitem__().append.assert_called_once_with(1.0)
335311

336312
# TensorBoard
337313
if _TENSORBOARD_AVAILABLE:
@@ -345,14 +321,14 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, monkeypatch
345321
logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
346322

347323
# WandB
348-
with mock.patch("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True):
349-
logger = _instantiate_logger(WandbLogger, save_dir=tmp_path, prefix=prefix)
350-
wandb_mock.run = None
351-
wandb_mock.init().step = 0
352-
logger.log_metrics({"test": 1.0}, step=0)
353-
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
324+
logger = _instantiate_logger(WandbLogger, save_dir=tmp_path, prefix=prefix)
325+
wandb_mock.run = None
326+
wandb_mock.init().step = 0
327+
logger.log_metrics({"test": 1.0}, step=0)
328+
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
354329

355330

331+
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
356332
def test_logger_default_name(mlflow_mock, monkeypatch, tmp_path):
357333
"""Test that the default logger name is lightning_logs."""
358334
# CSV
@@ -370,14 +346,11 @@ def test_logger_default_name(mlflow_mock, monkeypatch, tmp_path):
370346
assert logger.name == "lightning_logs"
371347

372348
# MLflow
373-
with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
374-
"lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()
375-
):
376-
client = mlflow_mock.tracking.MlflowClient()
377-
client.get_experiment_by_name.return_value = None
378-
logger = _instantiate_logger(MLFlowLogger, save_dir=tmp_path)
379-
380-
_ = logger.experiment
381-
logger._mlflow_client.create_experiment.assert_called_with(name="lightning_logs", artifact_location=ANY)
382-
# on MLFLowLogger `name` refers to the experiment id
383-
# assert logger.experiment.get_experiment(logger.name).name == "lightning_logs"
349+
client = mlflow_mock.tracking.MlflowClient()
350+
client.get_experiment_by_name.return_value = None
351+
logger = _instantiate_logger(MLFlowLogger, save_dir=tmp_path)
352+
353+
_ = logger.experiment
354+
logger._mlflow_client.create_experiment.assert_called_with(name="lightning_logs", artifact_location=ANY)
355+
# on MLFLowLogger `name` refers to the experiment id
356+
# assert logger.experiment.get_experiment(logger.name).name == "lightning_logs"

0 commit comments

Comments
 (0)