44
44
from neptune .new .types import File as NeptuneFile
45
45
except ModuleNotFoundError :
46
46
import neptune
47
- from neptune .exceptions import NeptuneLegacyProjectException
47
+ from neptune .exceptions import NeptuneLegacyProjectException , NeptuneOfflineModeFetchException
48
48
from neptune .run import Run
49
49
from neptune .types import File as NeptuneFile
50
50
else :
@@ -266,51 +266,64 @@ def __init__(
266
266
prefix : str = "training" ,
267
267
** neptune_run_kwargs ,
268
268
):
269
-
270
269
# verify if user passed proper init arguments
271
270
self ._verify_input_arguments (api_key , project , name , run , neptune_run_kwargs )
271
+ if neptune is None :
272
+ raise ModuleNotFoundError (
273
+ "You want to use the `Neptune` logger which is not installed yet, install it with"
274
+ " `pip install neptune-client`."
275
+ )
272
276
273
277
super ().__init__ ()
274
278
self ._log_model_checkpoints = log_model_checkpoints
275
279
self ._prefix = prefix
280
+ self ._run_name = name
281
+ self ._project_name = project
282
+ self ._api_key = api_key
283
+ self ._run_instance = run
284
+ self ._neptune_run_kwargs = neptune_run_kwargs
285
+ self ._run_short_id = None
276
286
277
- self ._run_instance = self ._init_run_instance (api_key , project , name , run , neptune_run_kwargs )
287
+ if self ._run_instance is not None :
288
+ self ._retrieve_run_data ()
278
289
279
- self ._run_short_id = self .run ._short_id # skipcq: PYL-W0212
290
+ # make sure that we've log integration version for outside `Run` instances
291
+ self ._run_instance [_INTEGRATION_VERSION_KEY ] = __version__
292
+
293
def _retrieve_run_data(self):
    """Fetch short id and name from the live run instance.

    Falls back to a placeholder name when the run is in offline mode and
    server-side metadata cannot be fetched.
    """
    try:
        self._run_instance.wait()
        # Access the already-initialized instance directly instead of going
        # through the lazy `run` property: this method is called from inside
        # run initialization, and re-entering the property there is fragile.
        self._run_short_id = self._run_instance._short_id  # skipcq: PYL-W0212
        self._run_name = self._run_instance["sys/name"].fetch()
    except NeptuneOfflineModeFetchException:
        # Offline runs have no server-side metadata to fetch.
        self._run_name = "offline-name"
285
300
286
@property
def _neptune_init_args(self):
    """Keyword arguments for ``neptune.init`` that re-create/resume this run.

    Only keys whose values are known (not ``None``) are included, so
    ``neptune.init(**self._neptune_init_args)`` falls back to Neptune's own
    defaults for the rest.
    """
    args = {}
    # Backward compatibility in case of previous version retrieval
    try:
        # Copy: callers below add keys, and we must not mutate the stored
        # kwargs that the user passed to the logger constructor.
        args = dict(self._neptune_run_kwargs)
    except AttributeError:
        pass

    if self._project_name is not None:
        args["project"] = self._project_name

    if self._api_key is not None:
        args["api_token"] = self._api_key

    if self._run_short_id is not None:
        args["run"] = self._run_short_id

    # Backward compatibility in case of previous version retrieval
    try:
        if self._run_name is not None:
            args["name"] = self._run_name
    except AttributeError:
        pass

    return args
314
327
315
328
def _construct_path_with_prefix (self , * keys ) -> str :
316
329
"""Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
@@ -379,7 +392,7 @@ def __getstate__(self):
379
392
380
393
def __setstate__(self, state):
    # Restore pickled attributes, then re-create the Neptune run handle from
    # the stored init arguments — the live Run object itself is not picklable,
    # so it must be re-initialized on unpickling (e.g. in DDP worker processes).
    self.__dict__ = state
    self._run_instance = neptune.init(**self._neptune_init_args)
383
396
384
397
@property
385
398
@rank_zero_experiment
@@ -412,8 +425,23 @@ def training_step(self, batch, batch_idx):
412
425
return self .run
413
426
414
427
@property
@rank_zero_experiment
def run(self) -> Run:
    """Actual Neptune run object, created lazily on first access.

    Raises:
        TypeError: If the target project still uses the legacy Neptune
            project structure and cannot be used with this logger.
    """
    try:
        # Identity check, not truthiness: a live Run object must never be
        # discarded/re-created just because it happens to evaluate as falsy.
        if self._run_instance is None:
            self._run_instance = neptune.init(**self._neptune_init_args)
            self._retrieve_run_data()
            # make sure that we've log integration version for newly created
            self._run_instance[_INTEGRATION_VERSION_KEY] = __version__

        return self._run_instance
    except NeptuneLegacyProjectException as e:
        raise TypeError(
            f"Project {self._project_name} has not been migrated to the new structure."
            " You can still integrate it with the Neptune logger using legacy Python API"
            " available as part of neptune-contrib package:"
            " https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n"
        ) from e
417
445
418
446
@rank_zero_only
419
447
def log_hyperparams (self , params : Union [Dict [str , Any ], Namespace ]) -> None : # skipcq: PYL-W0221
@@ -473,13 +501,13 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti
473
501
474
502
for key , val in metrics .items ():
475
503
# `step` is ignored because Neptune expects strictly increasing step values which
476
- # Lighting does not always guarantee.
477
- self .experiment [key ].log (val )
504
+ # Lightning does not always guarantee.
505
+ self .run [key ].log (val )
478
506
479
507
@rank_zero_only
def finalize(self, status: str) -> None:
    # Record the final trainer status (e.g. "success") under the prefixed
    # "status" key, then let the base class close the run.
    if status:
        self.run[self._construct_path_with_prefix("status")] = status

    super().finalize(status)
485
513
@@ -493,12 +521,14 @@ def save_dir(self) -> Optional[str]:
493
521
"""
494
522
return os .path .join (os .getcwd (), ".neptune" )
495
523
524
@rank_zero_only
def log_model_summary(self, model, max_depth=-1):
    """Upload a text rendering of the model summary under ``<prefix>/model/summary``."""
    summary_text = str(ModelSummary(model=model, max_depth=max_depth))
    summary_file = neptune.types.File.from_content(content=summary_text, extension="txt")
    self.run[self._construct_path_with_prefix("model/summary")] = summary_file
501
530
531
+ @rank_zero_only
502
532
def after_save_checkpoint (self , checkpoint_callback : "ReferenceType[ModelCheckpoint]" ) -> None :
503
533
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
504
534
@@ -515,35 +545,33 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
515
545
if checkpoint_callback .last_model_path :
516
546
model_last_name = self ._get_full_model_name (checkpoint_callback .last_model_path , checkpoint_callback )
517
547
file_names .add (model_last_name )
518
- self .experiment [f"{ checkpoints_namespace } /{ model_last_name } " ].upload (checkpoint_callback .last_model_path )
548
+ self .run [f"{ checkpoints_namespace } /{ model_last_name } " ].upload (checkpoint_callback .last_model_path )
519
549
520
550
# save best k models
521
551
for key in checkpoint_callback .best_k_models .keys ():
522
552
model_name = self ._get_full_model_name (key , checkpoint_callback )
523
553
file_names .add (model_name )
524
- self .experiment [f"{ checkpoints_namespace } /{ model_name } " ].upload (key )
554
+ self .run [f"{ checkpoints_namespace } /{ model_name } " ].upload (key )
525
555
526
556
# log best model path and checkpoint
527
557
if checkpoint_callback .best_model_path :
528
- self .experiment [
529
- self ._construct_path_with_prefix ("model/best_model_path" )
530
- ] = checkpoint_callback .best_model_path
558
+ self .run [self ._construct_path_with_prefix ("model/best_model_path" )] = checkpoint_callback .best_model_path
531
559
532
560
model_name = self ._get_full_model_name (checkpoint_callback .best_model_path , checkpoint_callback )
533
561
file_names .add (model_name )
534
- self .experiment [f"{ checkpoints_namespace } /{ model_name } " ].upload (checkpoint_callback .best_model_path )
562
+ self .run [f"{ checkpoints_namespace } /{ model_name } " ].upload (checkpoint_callback .best_model_path )
535
563
536
564
# remove old models logged to experiment if they are not part of best k models at this point
537
- if self .experiment .exists (checkpoints_namespace ):
538
- exp_structure = self .experiment .get_structure ()
565
+ if self .run .exists (checkpoints_namespace ):
566
+ exp_structure = self .run .get_structure ()
539
567
uploaded_model_names = self ._get_full_model_names_from_exp_structure (exp_structure , checkpoints_namespace )
540
568
541
569
for file_to_drop in list (uploaded_model_names - file_names ):
542
- del self .experiment [f"{ checkpoints_namespace } /{ file_to_drop } " ]
570
+ del self .run [f"{ checkpoints_namespace } /{ file_to_drop } " ]
543
571
544
572
# log best model score
545
573
if checkpoint_callback .best_model_score :
546
- self .experiment [self ._construct_path_with_prefix ("model/best_model_score" )] = (
574
+ self .run [self ._construct_path_with_prefix ("model/best_model_score" )] = (
547
575
checkpoint_callback .best_model_score .cpu ().detach ().numpy ()
548
576
)
549
577
@@ -637,13 +665,11 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None
637
665
self ._signal_deprecated_api_usage ("log_artifact" , f"logger.run['{ key } ].log('path_to_file')" )
638
666
self .run [key ].log (destination )
639
667
640
def set_property(self, *args, **kwargs):
    """Deprecated legacy-API method; always raises via the deprecation signal."""
    # First argument names the deprecated call being reported: it must be
    # "set_property" — the previous value, "log_artifact", was a copy-paste slip
    # that made the deprecation message point at the wrong method.
    self._signal_deprecated_api_usage(
        "set_property", f"logger.run['{self._prefix}/{self.PARAMETERS_KEY}/key'].log(value)", raise_exception=True
    )
645
672
646
- @rank_zero_only
647
673
def append_tags (self , * args , ** kwargs ):
648
674
self ._signal_deprecated_api_usage (
649
675
"append_tags" , "logger.run['sys/tags'].add(['foo', 'bar'])" , raise_exception = True
0 commit comments