Skip to content

Commit 7d534bd

Browse files
Raalsky authored and awaelchli committed
Fixed uploading best model checkpoint in NeptuneLogger (#10369)
1 parent c20b7fb commit 7d534bd

File tree

3 files changed

+35
-19
lines changed

3 files changed

+35
-19
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1414
- Improved exception message if `rich` version is less than `10.2.2` ([#10839](https://github.com/PyTorchLightning/pytorch-lightning/pull/10839))
1515

1616

17+
- Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))
18+
19+
1720
## [1.5.4] - 2021-11-30
1821

1922
### Fixed

pytorch_lightning/loggers/neptune.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,16 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
523523
file_names.add(model_name)
524524
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(key)
525525

526+
# log best model path and checkpoint
527+
if checkpoint_callback.best_model_path:
528+
self.experiment[
529+
self._construct_path_with_prefix("model/best_model_path")
530+
] = checkpoint_callback.best_model_path
531+
532+
model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
533+
file_names.add(model_name)
534+
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
535+
526536
# remove old models logged to experiment if they are not part of best k models at this point
527537
if self.experiment.exists(checkpoints_namespace):
528538
exp_structure = self.experiment.get_structure()
@@ -531,11 +541,7 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
531541
for file_to_drop in list(uploaded_model_names - file_names):
532542
del self.experiment[f"{checkpoints_namespace}/{file_to_drop}"]
533543

534-
# log best model path and best model score
535-
if checkpoint_callback.best_model_path:
536-
self.experiment[
537-
self._construct_path_with_prefix("model/best_model_path")
538-
] = checkpoint_callback.best_model_path
544+
# log best model score
539545
if checkpoint_callback.best_model_score:
540546
self.experiment[self._construct_path_with_prefix("model/best_model_score")] = (
541547
checkpoint_callback.best_model_score.cpu().detach().numpy()
@@ -544,7 +550,7 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
544550
@staticmethod
545551
def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> str:
546552
"""Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
547-
expected_model_path = f"{checkpoint_callback.dirpath}/"
553+
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
548554
if not model_path.startswith(expected_model_path):
549555
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
550556
return model_path[len(expected_model_path) :]

tests/loggers/test_neptune.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -276,14 +276,15 @@ def test_after_save_checkpoint(self, neptune):
276276
logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(
277277
api_key="test", project="project", **prefix
278278
)
279+
models_root_dir = os.path.join("path", "to", "models")
279280
cb_mock = MagicMock(
280-
dirpath="path/to/models",
281-
last_model_path="path/to/models/last",
281+
dirpath=models_root_dir,
282+
last_model_path=os.path.join(models_root_dir, "last"),
282283
best_k_models={
283-
"path/to/models/model1": None,
284-
"path/to/models/model2/with/slashes": None,
284+
f"{os.path.join(models_root_dir, 'model1')}": None,
285+
f"{os.path.join(models_root_dir, 'model2/with/slashes')}": None,
285286
},
286-
best_model_path="path/to/models/best_model",
287+
best_model_path=os.path.join(models_root_dir, "best_model"),
287288
best_model_score=None,
288289
)
289290

@@ -292,19 +293,21 @@ def test_after_save_checkpoint(self, neptune):
292293

293294
# then:
294295
self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
295-
self.assertEqual(run_instance_mock.__getitem__.call_count, 3)
296-
self.assertEqual(run_attr_mock.upload.call_count, 3)
296+
self.assertEqual(run_instance_mock.__getitem__.call_count, 4)
297+
self.assertEqual(run_attr_mock.upload.call_count, 4)
297298
run_instance_mock.__setitem__.assert_called_once_with(
298-
f"{model_key_prefix}/best_model_path", "path/to/models/best_model"
299+
f"{model_key_prefix}/best_model_path", os.path.join(models_root_dir, "best_model")
299300
)
300301
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/last")
301302
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
302303
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")
304+
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/best_model")
303305
run_attr_mock.upload.assert_has_calls(
304306
[
305-
call("path/to/models/last"),
306-
call("path/to/models/model1"),
307-
call("path/to/models/model2/with/slashes"),
307+
call(os.path.join(models_root_dir, "last")),
308+
call(os.path.join(models_root_dir, "model1")),
309+
call(os.path.join(models_root_dir, "model2/with/slashes")),
310+
call(os.path.join(models_root_dir, "best_model")),
308311
]
309312
)
310313

@@ -394,8 +397,12 @@ def test__get_full_model_name(self):
394397
# given:
395398
SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
396399
test_input_data = [
397-
("key.ext", "foo/bar/key.ext", SimpleCheckpoint(dirpath="foo/bar")),
398-
("key/in/parts.ext", "foo/bar/key/in/parts.ext", SimpleCheckpoint(dirpath="foo/bar")),
400+
("key.ext", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
401+
(
402+
"key/in/parts.ext",
403+
os.path.join("foo", "bar", "key/in/parts.ext"),
404+
SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
405+
),
399406
]
400407

401408
# expect:

0 commit comments

Comments
 (0)