44
44
from neptune .new .types import File as NeptuneFile
45
45
except ModuleNotFoundError :
46
46
import neptune
47
- from neptune .exceptions import NeptuneLegacyProjectException
47
+ from neptune .exceptions import NeptuneLegacyProjectException , NeptuneOfflineModeFetchException
48
48
from neptune .run import Run
49
49
from neptune .types import File as NeptuneFile
50
50
else :
@@ -273,44 +273,53 @@ def __init__(
273
273
super ().__init__ ()
274
274
self ._log_model_checkpoints = log_model_checkpoints
275
275
self ._prefix = prefix
276
+ self ._run_name = name
277
+ self ._project_name = project
278
+ self ._api_key = api_key
279
+ self ._run_instance = run
280
+ self ._neptune_run_kwargs = neptune_run_kwargs
281
+ self ._run_short_id = None
282
+
283
+ if self ._run_instance is not None :
284
+ self ._retrieve_run_data ()
276
285
277
- self ._run_instance = self ._init_run_instance (api_key , project , name , run , neptune_run_kwargs )
286
+ # make sure that we've log integration version for outside `Run` instances
287
+ self ._run_instance [_INTEGRATION_VERSION_KEY ] = __version__
278
288
279
def _retrieve_run_data(self):
    """Fetch the short id and name of the attached run, tolerating offline mode.

    Populates ``self._run_short_id`` and ``self._run_name`` from the connected
    Neptune run. In offline mode the backend cannot be queried
    (``NeptuneOfflineModeFetchException``), so the name falls back to the
    placeholder ``"offline-name"``.
    """
    try:
        self._run_instance.wait()
        # Read from the instance attribute directly; going through the lazy
        # ``self.run`` property (as the original did) re-enters the
        # lazy-initialization logic for no benefit.
        self._run_short_id = self._run_instance._short_id  # skipcq: PYL-W0212
        self._run_name = self._run_instance["sys/name"].fetch()
    except NeptuneOfflineModeFetchException:
        self._run_name = "offline-name"
285
296
286
@property
def _neptune_init_args(self):
    """Build the keyword arguments for ``neptune.init`` used to (re)create the run.

    Only values that were actually provided are included, so Neptune's own
    defaults apply otherwise. Attribute access is wrapped in
    ``try``/``except AttributeError`` so loggers unpickled from an older
    version of this class (which lacked some attributes) still work.

    Returns:
        dict: kwargs suitable for ``neptune.init(**...)``.
    """
    # Backward compatibility in case of previous version retrieval.
    # NOTE: copy the stored kwargs — the original assigned the dict by
    # reference, so the ``args[...] = ...`` writes below silently mutated
    # ``self._neptune_run_kwargs`` on every access of this property.
    try:
        args = dict(self._neptune_run_kwargs)
    except AttributeError:
        args = {}

    if self._project_name is not None:
        args["project"] = self._project_name

    if self._api_key is not None:
        args["api_token"] = self._api_key

    if self._run_short_id is not None:
        args["run"] = self._run_short_id

    # Backward compatibility in case of previous version retrieval
    try:
        if self._run_name is not None:
            args["name"] = self._run_name
    except AttributeError:
        pass

    return args
314
323
315
324
def _construct_path_with_prefix (self , * keys ) -> str :
316
325
"""Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
@@ -379,7 +388,7 @@ def __getstate__(self):
379
388
380
389
def __setstate__(self, state):
    """Restore the logger from its pickled state.

    The Neptune ``Run`` handle is not picklable, so after restoring the
    attribute dict the run is re-opened from the stored connection
    arguments (project / api token / short id / name).
    """
    self.__dict__ = state
    # Re-create the run connection; must happen AFTER ``__dict__`` is
    # restored because ``_neptune_init_args`` reads those attributes.
    self._run_instance = neptune.init(**self._neptune_init_args)
383
392
384
393
@property
385
394
@rank_zero_experiment
@@ -412,8 +421,24 @@ def training_step(self, batch, batch_idx):
412
421
return self .run
413
422
414
423
@property
@rank_zero_experiment
def run(self) -> Run:
    """Return the Neptune ``Run``, creating it lazily on first access.

    A freshly created run also gets the integration version recorded
    under the integration-version key.

    Raises:
        TypeError: if the target project has not been migrated to the new
            Neptune structure (legacy project).
    """
    try:
        if self._run_instance:
            return self._run_instance
        # Lazy initialization: open the run from the stored init arguments.
        self._run_instance = neptune.init(**self._neptune_init_args)
        self._retrieve_run_data()
        # make sure that we've log integration version for newly created
        self._run_instance[_INTEGRATION_VERSION_KEY] = __version__
        return self._run_instance
    except NeptuneLegacyProjectException as err:
        raise TypeError(
            f"""Project {self._project_name} has not been migrated to the new structure.
            You can still integrate it with the Neptune logger using legacy Python API
            available as part of neptune-contrib package:
              - https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n
            """
        ) from err
417
442
418
443
@rank_zero_only
419
444
def log_hyperparams (self , params : Union [Dict [str , Any ], Namespace ]) -> None : # skipcq: PYL-W0221
@@ -474,12 +499,12 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti
474
499
for key , val in metrics .items ():
475
500
# `step` is ignored because Neptune expects strictly increasing step values which
476
501
# Lightning does not always guarantee.
477
- self .experiment [key ].log (val )
502
+ self .run [key ].log (val )
478
503
479
504
@rank_zero_only
def finalize(self, status: str) -> None:
    """Record the final ``status`` under the prefixed key, then delegate to the base logger."""
    if status:
        status_key = self._construct_path_with_prefix("status")
        self.run[status_key] = status

    super().finalize(status)
485
510
@@ -493,12 +518,14 @@ def save_dir(self) -> Optional[str]:
493
518
"""
494
519
return os .path .join (os .getcwd (), ".neptune" )
495
520
521
@rank_zero_only
def log_model_summary(self, model, max_depth=-1):
    """Upload a text rendering of the model's layer summary to the run.

    Args:
        model: the LightningModule to summarize.
        max_depth: maximum depth of layer nesting to include (-1 = unlimited).
    """
    summary_text = str(ModelSummary(model=model, max_depth=max_depth))
    summary_file = neptune.types.File.from_content(content=summary_text, extension="txt")
    self.run[self._construct_path_with_prefix("model/summary")] = summary_file
501
527
528
+ @rank_zero_only
502
529
def after_save_checkpoint (self , checkpoint_callback : "ReferenceType[ModelCheckpoint]" ) -> None :
503
530
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
504
531
@@ -515,35 +542,33 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
515
542
if checkpoint_callback .last_model_path :
516
543
model_last_name = self ._get_full_model_name (checkpoint_callback .last_model_path , checkpoint_callback )
517
544
file_names .add (model_last_name )
518
- self .experiment [f"{ checkpoints_namespace } /{ model_last_name } " ].upload (checkpoint_callback .last_model_path )
545
+ self .run [f"{ checkpoints_namespace } /{ model_last_name } " ].upload (checkpoint_callback .last_model_path )
519
546
520
547
# save best k models
521
548
for key in checkpoint_callback .best_k_models .keys ():
522
549
model_name = self ._get_full_model_name (key , checkpoint_callback )
523
550
file_names .add (model_name )
524
- self .experiment [f"{ checkpoints_namespace } /{ model_name } " ].upload (key )
551
+ self .run [f"{ checkpoints_namespace } /{ model_name } " ].upload (key )
525
552
526
553
# log best model path and checkpoint
527
554
if checkpoint_callback .best_model_path :
528
- self .experiment [
529
- self ._construct_path_with_prefix ("model/best_model_path" )
530
- ] = checkpoint_callback .best_model_path
555
+ self .run [self ._construct_path_with_prefix ("model/best_model_path" )] = checkpoint_callback .best_model_path
531
556
532
557
model_name = self ._get_full_model_name (checkpoint_callback .best_model_path , checkpoint_callback )
533
558
file_names .add (model_name )
534
- self .experiment [f"{ checkpoints_namespace } /{ model_name } " ].upload (checkpoint_callback .best_model_path )
559
+ self .run [f"{ checkpoints_namespace } /{ model_name } " ].upload (checkpoint_callback .best_model_path )
535
560
536
561
# remove old models logged to experiment if they are not part of best k models at this point
537
- if self .experiment .exists (checkpoints_namespace ):
538
- exp_structure = self .experiment .get_structure ()
562
+ if self .run .exists (checkpoints_namespace ):
563
+ exp_structure = self .run .get_structure ()
539
564
uploaded_model_names = self ._get_full_model_names_from_exp_structure (exp_structure , checkpoints_namespace )
540
565
541
566
for file_to_drop in list (uploaded_model_names - file_names ):
542
- del self .experiment [f"{ checkpoints_namespace } /{ file_to_drop } " ]
567
+ del self .run [f"{ checkpoints_namespace } /{ file_to_drop } " ]
543
568
544
569
# log best model score
545
570
if checkpoint_callback .best_model_score :
546
- self .experiment [self ._construct_path_with_prefix ("model/best_model_score" )] = (
571
+ self .run [self ._construct_path_with_prefix ("model/best_model_score" )] = (
547
572
checkpoint_callback .best_model_score .cpu ().detach ().numpy ()
548
573
)
549
574
@@ -637,13 +662,11 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None
637
662
self ._signal_deprecated_api_usage ("log_artifact" , f"logger.run['{ key } ].log('path_to_file')" )
638
663
self .run [key ].log (destination )
639
664
640
def set_property(self, *args, **kwargs):
    """Deprecated legacy-API entry point; always raises via the deprecation signal.

    Points callers at the new-API replacement for setting a property.
    """
    # FIX: the original reported "log_artifact" here — a copy-paste slip from
    # the method above; the deprecated API being used is set_property
    # (compare append_tags, which correctly reports its own name).
    self._signal_deprecated_api_usage(
        "set_property", f"logger.run['{self._prefix}/{self.PARAMETERS_KEY}/key'].log(value)", raise_exception=True
    )
645
669
646
- @rank_zero_only
647
670
def append_tags (self , * args , ** kwargs ):
648
671
self ._signal_deprecated_api_usage (
649
672
"append_tags" , "logger.run['sys/tags'].add(['foo', 'bar'])" , raise_exception = True
0 commit comments