
Commit 86477b0

Fix type hint for filepath

1 parent e0f2e04 commit 86477b0

4 files changed: +9 −7 lines changed
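For context, the commit threads the `_PATH` alias from `pytorch_lightning.utilities.types` through four `filepath` signatures. A minimal sketch of what such an alias typically looks like, assuming it is the usual str-or-Path union (the actual definition lives in `pytorch_lightning/utilities/types.py` and may differ):

```python
# Hypothetical sketch of the alias this commit imports; check
# pytorch_lightning/utilities/types.py for the real definition.
from pathlib import Path
from typing import Union

_PATH = Union[str, Path]  # a filepath may be a plain string or a pathlib.Path
```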

pytorch_lightning/accelerators/accelerator.py (2 additions, 2 deletions)

```diff
@@ -28,7 +28,7 @@
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

 if _NATIVE_AMP_AVAILABLE:
     from torch.cuda.amp import GradScaler
@@ -393,7 +393,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
         with self.training_type_plugin.model_sharded_context():
             yield

-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

         Args:
```
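Widening `filepath` from `str` to `_PATH` mainly matters to static type checkers; code that accepts the union can normalize either spelling with `os.fspath`. A self-contained sketch of that pattern (`save_blob` is illustrative, not part of this commit):

```python
import os
from pathlib import Path
from typing import Union

_PATH = Union[str, os.PathLike]

def save_blob(data: bytes, filepath: _PATH) -> None:
    # os.fspath collapses both str and pathlib.Path to a plain string path.
    path = os.fspath(filepath)
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "wb") as f:
        f.write(data)

save_blob(b"\x00", "ckpts/a.bin")            # str is accepted
save_blob(b"\x00", Path("ckpts") / "b.bin")  # so is a Path
```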

pytorch_lightning/plugins/training_type/training_type_plugin.py (3 additions, 3 deletions)

```diff
@@ -26,7 +26,7 @@
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins import TorchCheckpointIO
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
-from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
+from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT

 TBroadcast = TypeVar("T")

@@ -259,7 +259,7 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
         model = self.lightning_module
         return model.state_dict()

-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

         Args:
@@ -269,7 +269,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         if self.should_rank_save_checkpoint:
             return self.checkpoint_io.save_checkpoint(checkpoint, filepath)

-    def remove_checkpoint(self, filepath: str) -> None:
+    def remove_checkpoint(self, filepath: _PATH) -> None:
         """Remove checkpoint filepath from the filesystem.

         Args:
```
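`remove_checkpoint` gets the same widening. A standalone sketch of a remover over the union type (not PL's implementation, which delegates to its `checkpoint_io` plugin):

```python
import os
from pathlib import Path
from typing import Union

_PATH = Union[str, os.PathLike]

def remove_checkpoint(filepath: _PATH) -> None:
    # Routing through Path gives uniform exists/unlink handling
    # whether the caller passed a str or a Path.
    path = Path(filepath)
    if path.exists():
        path.unlink()
```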

pytorch_lightning/trainer/connectors/checkpoint_connector.py (2 additions, 1 deletion)

```diff
@@ -26,6 +26,7 @@
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
+from pytorch_lightning.utilities.types import _PATH
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

 if _OMEGACONF_AVAILABLE:
@@ -430,7 +431,7 @@ def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
         ckpt_number = max_suffix if max_suffix is not None else 0
         return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"

-    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
+    def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

         Args:
```
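The connector's saver imports `atomic_save` from `pytorch_lightning.utilities.cloud_io`; a standalone sketch of the write-then-rename pattern that name suggests (not PL's actual implementation):

```python
import os
import tempfile
from typing import Union

def atomic_write(data: bytes, filepath: Union[str, os.PathLike]) -> None:
    path = os.fspath(filepath)
    # Write to a temp file in the same directory, then rename over the
    # target: os.replace is atomic on POSIX within one filesystem, so a
    # crash mid-write never leaves a truncated checkpoint behind.
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp, path)
    except BaseException:
        os.unlink(tmp)
        raise
```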

pytorch_lightning/trainer/properties.py (2 additions, 1 deletion)

```diff
@@ -48,6 +48,7 @@
 )
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.types import _PATH


 class TrainerProperties(ABC):
@@ -388,7 +389,7 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
     def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
         return self.checkpoint_connector.resume_checkpoint_path

-    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
+    def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         self.checkpoint_connector.save_checkpoint(filepath, weights_only)

     """
```
