
Fix type hint for filepath {save/load/remove}_checkpoint #9434


Merged: 1 commit, merged on Sep 10, 2021
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -28,7 +28,7 @@
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 
 if _NATIVE_AMP_AVAILABLE:
     from torch.cuda.amp import GradScaler
@@ -393,7 +393,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
         with self.training_type_plugin.model_sharded_context():
             yield
 
-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
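
Note: `_PATH` is the path alias from `pytorch_lightning.utilities.types`. A minimal sketch of the relevant definition, assuming the layout of that module at the time of this PR:

    from pathlib import Path
    from typing import Union

    # Path-like values: either a plain string or a pathlib.Path object.
    _PATH = Union[str, Path]
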
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -26,7 +26,7 @@
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins import TorchCheckpointIO
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
-from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
+from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT
 
 TBroadcast = TypeVar("T")
 
@@ -259,7 +259,7 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
         model = self.lightning_module
         return model.state_dict()
 
-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
@@ -269,7 +269,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         if self.should_rank_save_checkpoint:
             return self.checkpoint_io.save_checkpoint(checkpoint, filepath)
 
-    def remove_checkpoint(self, filepath: str) -> None:
+    def remove_checkpoint(self, filepath: _PATH) -> None:
         """Remove checkpoint filepath from the filesystem.
 
         Args:
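
Note: the plugin forwards the actual write to `self.checkpoint_io.save_checkpoint(checkpoint, filepath)` (visible in the context above), so custom `CheckpointIO` implementations receive the same path-like values. A minimal sketch of such a subclass; the class name and the simplified signatures are hypothetical, not part of this PR:

    from pathlib import Path
    from typing import Any, Dict, Union

    import torch

    from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO

    _PATH = Union[str, Path]  # mirrors pytorch_lightning.utilities.types._PATH


    class PathAwareCheckpointIO(CheckpointIO):  # hypothetical subclass, for illustration only
        def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH) -> None:
            # Normalizing to Path up front makes str and Path inputs behave identically.
            torch.save(checkpoint, Path(path))

        def load_checkpoint(self, path: _PATH) -> Dict[str, Any]:
            # Signature simplified for the sketch; the real interface may take extra arguments.
            return torch.load(Path(path))
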
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -26,6 +26,7 @@
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
+from pytorch_lightning.utilities.types import _PATH
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 
 if _OMEGACONF_AVAILABLE:
@@ -430,7 +431,7 @@ def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
         ckpt_number = max_suffix if max_suffix is not None else 0
         return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"
 
-    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
+    def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
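
Note: `_PATH` gives a single name to the `Union[str, Path]` spelling that `get_max_ckpt_path_from_folder` above already uses, and the f-string formatting in that method yields the same string for either argument type. An illustrative check, not part of the PR:

    from pathlib import Path

    # str and pathlib.Path format identically inside the f-string above.
    assert f"{'ckpts'}/hpc_ckpt_1.ckpt" == f"{Path('ckpts')}/hpc_ckpt_1.ckpt"
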
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/properties.py
@@ -48,6 +48,7 @@
 )
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.types import _PATH
 
 
 class TrainerProperties(ABC):
@@ -388,7 +389,7 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
     def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
         return self.checkpoint_connector.resume_checkpoint_path
 
-    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
+    def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         self.checkpoint_connector.save_checkpoint(filepath, weights_only)
 
     """
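
Note: end to end, the public `Trainer.save_checkpoint` now advertises the argument types it already accepted at runtime. A minimal usage sketch; it assumes a `Trainer` whose model has already been set up, which is omitted here:

    from pathlib import Path

    import pytorch_lightning as pl

    trainer = pl.Trainer()  # assumes a LightningModule has been fit first (omitted)

    # Both spellings now satisfy the type checker; runtime behavior is unchanged.
    trainer.save_checkpoint("checkpoints/final.ckpt")
    trainer.save_checkpoint(Path("checkpoints") / "final.ckpt")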