Skip to content

Commit ed84cef

Browse files
authored
Removed duplicated file extension when uploading model checkpoints with NeptuneLogger (#11015)
1 parent 5576fbc commit ed84cef

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
102102
* Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
103103

104104

105+
- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))
106+
107+
105108
### Deprecated
106109

107110
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))

pytorch_lightning/loggers/neptune.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,10 @@ def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[Mo
553553
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
554554
if not model_path.startswith(expected_model_path):
555555
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
556-
return model_path[len(expected_model_path) :]
556+
# Remove extension from filepath
557+
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
558+
559+
return filepath
557560

558561
@classmethod
559562
def _get_full_model_names_from_exp_structure(cls, exp_structure: dict, namespace: str) -> Set[str]:

tests/loggers/test_neptune.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,9 @@ def test__get_full_model_name(self):
397397
# given:
398398
SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
399399
test_input_data = [
400-
("key.ext", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
400+
("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
401401
(
402-
"key/in/parts.ext",
402+
"key/in/parts",
403403
os.path.join("foo", "bar", "key/in/parts.ext"),
404404
SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
405405
),

0 commit comments

Comments
 (0)