Skip to content

Commit 93cda24

Browse files
awaelchliRaalsky
authored andcommitted
Removed duplicated file extension when uploading model checkpoints with NeptuneLogger (#11015)
Co-authored-by: Rafał Jankowski <[email protected]>
1 parent 1f63923 commit 93cda24

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99
### Fixed
1010

1111
- Fixed a bug where the DeepSpeedPlugin arguments `cpu_checkpointing` and `contiguous_memory_optimization` were not being forwarded to deepspeed correctly ([#10874](https://github.com/PyTorchLightning/pytorch-lightning/issues/10874))
12-
13-
14-
-
12+
- Fixed an issue with `NeptuneLogger` causing checkpoints to be uploaded with a duplicated file extension ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/issues/11015))
1513

1614

1715
## [1.5.5] - 2021-12-07

pytorch_lightning/loggers/neptune.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,10 @@ def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[Mo
553553
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
554554
if not model_path.startswith(expected_model_path):
555555
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
556-
return model_path[len(expected_model_path) :]
556+
# Remove extension from filepath
557+
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
558+
559+
return filepath
557560

558561
@classmethod
559562
def _get_full_model_names_from_exp_structure(cls, exp_structure: dict, namespace: str) -> Set[str]:

tests/loggers/test_neptune.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,9 @@ def test__get_full_model_name(self):
397397
# given:
398398
SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
399399
test_input_data = [
400-
("key.ext", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
400+
("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
401401
(
402-
"key/in/parts.ext",
402+
"key/in/parts",
403403
os.path.join("foo", "bar", "key/in/parts.ext"),
404404
SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
405405
),

0 commit comments

Comments
 (0)