11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- import contextlib
15
14
import inspect
16
15
import os
17
16
import pickle
37
36
from tests_pytorch .helpers .runif import RunIf
38
37
from tests_pytorch .loggers .test_comet import _patch_comet_atexit
39
38
from tests_pytorch .loggers .test_mlflow import mock_mlflow_run_creation
40
- from tests_pytorch .loggers .test_neptune import create_neptune_mock
41
-
42
- LOGGER_CTX_MANAGERS = (
43
- mock .patch ("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True ),
44
- mock .patch ("lightning.pytorch.loggers.mlflow._get_resolve_tags" , Mock ()),
45
- mock .patch ("lightning.pytorch.loggers.neptune.neptune" , new_callable = create_neptune_mock ),
46
- mock .patch ("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE" , return_value = True ),
47
- mock .patch ("lightning.pytorch.loggers.neptune.Run" , new = mock .Mock ),
48
- mock .patch ("lightning.pytorch.loggers.neptune.Handler" , new = mock .Mock ),
49
- mock .patch ("lightning.pytorch.loggers.neptune.File" , new = mock .Mock ()),
50
- )
39
+
51
40
ALL_LOGGER_CLASSES = (
52
41
CometLogger ,
53
42
CSVLogger ,
@@ -79,18 +68,11 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
79
68
80
69
81
70
@mock .patch .dict (os .environ , {})
82
- @mock .patch ("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE" , True )
83
- @mock .patch ("lightning.pytorch.loggers.comet._COMET_AVAILABLE" , True )
71
+ @mock .patch ("lightning.pytorch.loggers.mlflow._get_resolve_tags" , Mock ())
84
72
@pytest .mark .parametrize ("logger_class" , ALL_LOGGER_CLASSES )
85
- def test_loggers_fit_test_all (logger_class , mlflow_mock , wandb_mock , comet_mock , tmp_path ):
73
+ def test_loggers_fit_test_all (logger_class , mlflow_mock , wandb_mock , comet_mock , neptune_mock , tmp_path ):
86
74
"""Verify that basic functionality of all loggers."""
87
- with contextlib .ExitStack () as stack :
88
- for mgr in LOGGER_CTX_MANAGERS :
89
- stack .enter_context (mgr )
90
- _test_loggers_fit_test (tmp_path , logger_class )
91
-
92
75
93
- def _test_loggers_fit_test (tmp_path , logger_class ):
94
76
class CustomModel (BoringModel ):
95
77
def training_step (self , batch , batch_idx ):
96
78
loss = self .step (batch )
@@ -300,38 +282,32 @@ def _test_logger_initialization(tmp_path, logger_class):
300
282
301
283
302
284
@mock .patch .dict (os .environ , {})
303
- def test_logger_with_prefix_all (mlflow_mock , wandb_mock , comet_mock , monkeypatch , tmp_path ):
285
+ @mock .patch ("lightning.pytorch.loggers.mlflow._get_resolve_tags" , Mock ())
286
+ def test_logger_with_prefix_all (mlflow_mock , wandb_mock , comet_mock , neptune_mock , monkeypatch , tmp_path ):
304
287
"""Test that prefix is added at the beginning of the metric keys."""
305
288
prefix = "tmp"
306
289
307
290
# Comet
308
- with mock .patch ("lightning.pytorch.loggers.comet._COMET_AVAILABLE" , return_value = True ):
309
- _patch_comet_atexit (monkeypatch )
310
- logger = _instantiate_logger (CometLogger , save_dir = tmp_path , prefix = prefix )
311
- logger .log_metrics ({"test" : 1.0 }, step = 0 )
312
- logger .experiment .log_metrics .assert_called_once_with ({"tmp-test" : 1.0 }, epoch = None , step = 0 )
291
+ _patch_comet_atexit (monkeypatch )
292
+ logger = _instantiate_logger (CometLogger , save_dir = tmp_path , prefix = prefix )
293
+ logger .log_metrics ({"test" : 1.0 }, step = 0 )
294
+ logger .experiment .log_metrics .assert_called_once_with ({"tmp-test" : 1.0 }, epoch = None , step = 0 )
313
295
314
296
# MLflow
315
- with mock .patch ("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True ), mock .patch (
316
- "lightning.pytorch.loggers.mlflow._get_resolve_tags" , Mock ()
317
- ):
318
- Metric = mlflow_mock .entities .Metric
319
- logger = _instantiate_logger (MLFlowLogger , save_dir = tmp_path , prefix = prefix )
320
- logger .log_metrics ({"test" : 1.0 }, step = 0 )
321
- logger .experiment .log_batch .assert_called_once_with (
322
- run_id = ANY , metrics = [Metric (key = "tmp-test" , value = 1.0 , timestamp = ANY , step = 0 )]
323
- )
297
+ Metric = mlflow_mock .entities .Metric
298
+ logger = _instantiate_logger (MLFlowLogger , save_dir = tmp_path , prefix = prefix )
299
+ logger .log_metrics ({"test" : 1.0 }, step = 0 )
300
+ logger .experiment .log_batch .assert_called_once_with (
301
+ run_id = ANY , metrics = [Metric (key = "tmp-test" , value = 1.0 , timestamp = ANY , step = 0 )]
302
+ )
324
303
325
304
# Neptune
326
- with mock .patch ("lightning.pytorch.loggers.neptune.neptune" ), mock .patch (
327
- "lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE" , return_value = True
328
- ), mock .patch ("lightning.pytorch.loggers.neptune.Handler" , new = mock .Mock ):
329
- logger = _instantiate_logger (NeptuneLogger , api_key = "test" , project = "project" , save_dir = tmp_path , prefix = prefix )
330
- assert logger .experiment .__getitem__ .call_count == 0
331
- logger .log_metrics ({"test" : 1.0 }, step = 0 )
332
- assert logger .experiment .__getitem__ .call_count == 1
333
- logger .experiment .__getitem__ .assert_called_with ("tmp/test" )
334
- logger .experiment .__getitem__ ().append .assert_called_once_with (1.0 )
305
+ logger = _instantiate_logger (NeptuneLogger , api_key = "test" , project = "project" , save_dir = tmp_path , prefix = prefix )
306
+ assert logger .experiment .__getitem__ .call_count == 0
307
+ logger .log_metrics ({"test" : 1.0 }, step = 0 )
308
+ assert logger .experiment .__getitem__ .call_count == 1
309
+ logger .experiment .__getitem__ .assert_called_with ("tmp/test" )
310
+ logger .experiment .__getitem__ ().append .assert_called_once_with (1.0 )
335
311
336
312
# TensorBoard
337
313
if _TENSORBOARD_AVAILABLE :
@@ -345,14 +321,14 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, monkeypatch
345
321
logger .experiment .add_scalar .assert_called_once_with ("tmp-test" , 1.0 , 0 )
346
322
347
323
# WandB
348
- with mock .patch ("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE" , True ):
349
- logger = _instantiate_logger (WandbLogger , save_dir = tmp_path , prefix = prefix )
350
- wandb_mock .run = None
351
- wandb_mock .init ().step = 0
352
- logger .log_metrics ({"test" : 1.0 }, step = 0 )
353
- logger .experiment .log .assert_called_once_with ({"tmp-test" : 1.0 , "trainer/global_step" : 0 })
324
+ logger = _instantiate_logger (WandbLogger , save_dir = tmp_path , prefix = prefix )
325
+ wandb_mock .run = None
326
+ wandb_mock .init ().step = 0
327
+ logger .log_metrics ({"test" : 1.0 }, step = 0 )
328
+ logger .experiment .log .assert_called_once_with ({"tmp-test" : 1.0 , "trainer/global_step" : 0 })
354
329
355
330
331
+ @mock .patch ("lightning.pytorch.loggers.mlflow._get_resolve_tags" , Mock ())
356
332
def test_logger_default_name (mlflow_mock , monkeypatch , tmp_path ):
357
333
"""Test that the default logger name is lightning_logs."""
358
334
# CSV
@@ -370,14 +346,11 @@ def test_logger_default_name(mlflow_mock, monkeypatch, tmp_path):
370
346
assert logger .name == "lightning_logs"
371
347
372
348
# MLflow
373
- with mock .patch ("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True ), mock .patch (
374
- "lightning.pytorch.loggers.mlflow._get_resolve_tags" , Mock ()
375
- ):
376
- client = mlflow_mock .tracking .MlflowClient ()
377
- client .get_experiment_by_name .return_value = None
378
- logger = _instantiate_logger (MLFlowLogger , save_dir = tmp_path )
379
-
380
- _ = logger .experiment
381
- logger ._mlflow_client .create_experiment .assert_called_with (name = "lightning_logs" , artifact_location = ANY )
382
- # on MLFLowLogger `name` refers to the experiment id
383
- # assert logger.experiment.get_experiment(logger.name).name == "lightning_logs"
349
+ client = mlflow_mock .tracking .MlflowClient ()
350
+ client .get_experiment_by_name .return_value = None
351
+ logger = _instantiate_logger (MLFlowLogger , save_dir = tmp_path )
352
+
353
+ _ = logger .experiment
354
+ logger ._mlflow_client .create_experiment .assert_called_with (name = "lightning_logs" , artifact_location = ANY )
355
+ # on MLFLowLogger `name` refers to the experiment id
356
+ # assert logger.experiment.get_experiment(logger.name).name == "lightning_logs"
0 commit comments