Skip to content

Commit 12f1c37

Browse files
committed
Fixed NeptuneLogger when used with DDP
1 parent ed84cef commit 12f1c37

File tree

3 files changed

+75
-52
lines changed

3 files changed

+75
-52
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
255255

256256
### Fixed
257257

258-
-
258+
- Fixed `NeptuneLogger` when using DDP ([#11030](https://github.com/PyTorchLightning/pytorch-lightning/pull/11030))
259259

260260

261261
-

pytorch_lightning/loggers/neptune.py

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from neptune.new.types import File as NeptuneFile
4545
except ModuleNotFoundError:
4646
import neptune
47-
from neptune.exceptions import NeptuneLegacyProjectException
47+
from neptune.exceptions import NeptuneLegacyProjectException, NeptuneOfflineModeFetchException
4848
from neptune.run import Run
4949
from neptune.types import File as NeptuneFile
5050
else:
@@ -273,44 +273,53 @@ def __init__(
273273
super().__init__()
274274
self._log_model_checkpoints = log_model_checkpoints
275275
self._prefix = prefix
276+
self._run_name = name
277+
self._project_name = project
278+
self._api_key = api_key
279+
self._run_instance = run
280+
self._neptune_run_kwargs = neptune_run_kwargs
281+
self._run_short_id = None
282+
283+
if self._run_instance is not None:
284+
self._retrieve_run_data()
276285

277-
self._run_instance = self._init_run_instance(api_key, project, name, run, neptune_run_kwargs)
286+
# make sure that we've log integration version for outside `Run` instances
287+
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__
278288

279-
self._run_short_id = self.run._short_id # skipcq: PYL-W0212
289+
def _retrieve_run_data(self):
280290
try:
281-
self.run.wait()
291+
self._run_instance.wait()
292+
self._run_short_id = self.run._short_id # skipcq: PYL-W0212
282293
self._run_name = self._run_instance["sys/name"].fetch()
283294
except NeptuneOfflineModeFetchException:
284295
self._run_name = "offline-name"
285296

286-
def _init_run_instance(self, api_key, project, name, run, neptune_run_kwargs) -> Run:
287-
if run is not None:
288-
run_instance = run
289-
else:
290-
try:
291-
run_instance = neptune.init(
292-
project=project,
293-
api_token=api_key,
294-
name=name,
295-
**neptune_run_kwargs,
296-
)
297-
except NeptuneLegacyProjectException as e:
298-
raise TypeError(
299-
f"""Project {project} has not been migrated to the new structure.
300-
You can still integrate it with the Neptune logger using legacy Python API
301-
available as part of neptune-contrib package:
302-
- https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n
303-
"""
304-
) from e
305-
306-
# make sure that we've log integration version for both newly created and outside `Run` instances
307-
run_instance[_INTEGRATION_VERSION_KEY] = __version__
308-
309-
# keep api_key and project, they will be required when resuming Run for pickled logger
310-
self._api_key = api_key
311-
self._project_name = run_instance._project_name # skipcq: PYL-W0212
297+
@property
298+
def _neptune_init_args(self):
299+
args = {}
300+
# Backward compatibility in case of previous version retrieval
301+
try:
302+
args = self._neptune_run_kwargs
303+
except AttributeError:
304+
pass
305+
306+
if self._project_name is not None:
307+
args["project"] = self._project_name
312308

313-
return run_instance
309+
if self._api_key is not None:
310+
args["api_token"] = self._api_key
311+
312+
if self._run_short_id is not None:
313+
args["run"] = self._run_short_id
314+
315+
# Backward compatibility in case of previous version retrieval
316+
try:
317+
if self._run_name is not None:
318+
args["name"] = self._run_name
319+
except AttributeError:
320+
pass
321+
322+
return args
314323

315324
def _construct_path_with_prefix(self, *keys) -> str:
316325
"""Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
@@ -379,7 +388,7 @@ def __getstate__(self):
379388

380389
def __setstate__(self, state):
381390
self.__dict__ = state
382-
self._run_instance = neptune.init(project=self._project_name, api_token=self._api_key, run=self._run_short_id)
391+
self._run_instance = neptune.init(**self._neptune_init_args)
383392

384393
@property
385394
@rank_zero_experiment
@@ -412,8 +421,24 @@ def training_step(self, batch, batch_idx):
412421
return self.run
413422

414423
@property
424+
@rank_zero_experiment
415425
def run(self) -> Run:
416-
return self._run_instance
426+
try:
427+
if not self._run_instance:
428+
self._run_instance = neptune.init(**self._neptune_init_args)
429+
self._retrieve_run_data()
430+
# make sure that we've log integration version for newly created
431+
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__
432+
433+
return self._run_instance
434+
except NeptuneLegacyProjectException as e:
435+
raise TypeError(
436+
f"""Project {self._project_name} has not been migrated to the new structure.
437+
You can still integrate it with the Neptune logger using legacy Python API
438+
available as part of neptune-contrib package:
439+
- https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n
440+
"""
441+
) from e
417442

418443
@rank_zero_only
419444
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # skipcq: PYL-W0221
@@ -474,12 +499,12 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti
474499
for key, val in metrics.items():
475500
# `step` is ignored because Neptune expects strictly increasing step values which
476501
# Lightning does not always guarantee.
477-
self.experiment[key].log(val)
502+
self.run[key].log(val)
478503

479504
@rank_zero_only
480505
def finalize(self, status: str) -> None:
481506
if status:
482-
self.experiment[self._construct_path_with_prefix("status")] = status
507+
self.run[self._construct_path_with_prefix("status")] = status
483508

484509
super().finalize(status)
485510

@@ -493,12 +518,14 @@ def save_dir(self) -> Optional[str]:
493518
"""
494519
return os.path.join(os.getcwd(), ".neptune")
495520

521+
@rank_zero_only
496522
def log_model_summary(self, model, max_depth=-1):
497523
model_str = str(ModelSummary(model=model, max_depth=max_depth))
498-
self.experiment[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
524+
self.run[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
499525
content=model_str, extension="txt"
500526
)
501527

528+
@rank_zero_only
502529
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
503530
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
504531
@@ -515,35 +542,33 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
515542
if checkpoint_callback.last_model_path:
516543
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
517544
file_names.add(model_last_name)
518-
self.experiment[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
545+
self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
519546

520547
# save best k models
521548
for key in checkpoint_callback.best_k_models.keys():
522549
model_name = self._get_full_model_name(key, checkpoint_callback)
523550
file_names.add(model_name)
524-
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(key)
551+
self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
525552

526553
# log best model path and checkpoint
527554
if checkpoint_callback.best_model_path:
528-
self.experiment[
529-
self._construct_path_with_prefix("model/best_model_path")
530-
] = checkpoint_callback.best_model_path
555+
self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path
531556

532557
model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
533558
file_names.add(model_name)
534-
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
559+
self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
535560

536561
# remove old models logged to experiment if they are not part of best k models at this point
537-
if self.experiment.exists(checkpoints_namespace):
538-
exp_structure = self.experiment.get_structure()
562+
if self.run.exists(checkpoints_namespace):
563+
exp_structure = self.run.get_structure()
539564
uploaded_model_names = self._get_full_model_names_from_exp_structure(exp_structure, checkpoints_namespace)
540565

541566
for file_to_drop in list(uploaded_model_names - file_names):
542-
del self.experiment[f"{checkpoints_namespace}/{file_to_drop}"]
567+
del self.run[f"{checkpoints_namespace}/{file_to_drop}"]
543568

544569
# log best model score
545570
if checkpoint_callback.best_model_score:
546-
self.experiment[self._construct_path_with_prefix("model/best_model_score")] = (
571+
self.run[self._construct_path_with_prefix("model/best_model_score")] = (
547572
checkpoint_callback.best_model_score.cpu().detach().numpy()
548573
)
549574

@@ -637,13 +662,11 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None
637662
self._signal_deprecated_api_usage("log_artifact", f"logger.run['{key}].log('path_to_file')")
638663
self.run[key].log(destination)
639664

640-
@rank_zero_only
641665
def set_property(self, *args, **kwargs):
642666
self._signal_deprecated_api_usage(
643667
"log_artifact", f"logger.run['{self._prefix}/{self.PARAMETERS_KEY}/key'].log(value)", raise_exception=True
644668
)
645669

646-
@rank_zero_only
647670
def append_tags(self, *args, **kwargs):
648671
self._signal_deprecated_api_usage(
649672
"append_tags", "logger.run['sys/tags'].add(['foo', 'bar'])", raise_exception=True

tests/loggers/test_neptune.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def tmpdir_unittest_fixture(request, tmpdir):
7777
class TestNeptuneLogger(unittest.TestCase):
7878
def test_neptune_online(self, neptune):
7979
logger = NeptuneLogger(api_key="test", project="project")
80-
created_run_mock = logger._run_instance
80+
created_run_mock = logger.run
8181

8282
self.assertEqual(logger._run_instance, created_run_mock)
8383
self.assertEqual(logger.name, "Run test name")
@@ -109,7 +109,7 @@ def test_neptune_pickling(self, neptune):
109109
pickled_logger = pickle.dumps(logger)
110110
unpickled = pickle.loads(pickled_logger)
111111

112-
neptune.init.assert_called_once_with(project="test-project", api_token=None, run="TEST-42")
112+
neptune.init.assert_called_once_with(name="Test name", run=unpickleable_run._short_id)
113113
self.assertIsNotNone(unpickled.experiment)
114114

115115
@patch("pytorch_lightning.loggers.neptune.Run", Run)
@@ -360,7 +360,7 @@ def test_legacy_functions(self, neptune, neptune_file_mock, warnings_mock):
360360
logger = NeptuneLogger(api_key="test", project="project")
361361

362362
# test deprecated functions which will be shut down in pytorch-lightning 1.7.0
363-
attr_mock = logger._run_instance.__getitem__
363+
attr_mock = logger.run.__getitem__
364364
attr_mock.reset_mock()
365365
fake_image = {}
366366

0 commit comments

Comments (0)