
Commit ae13e0c

Refactor with_gpu type with simple typing
Normal and HPC load now use a common GPU type check (Lightning-AI#5300), so there is no longer any need to accept both bool and int.
1 parent abbfb05 · commit ae13e0c
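The removed docstring described the two call patterns this refactor unifies: `with_gpu=trainer.on_gpu` for a normal restore and `with_gpu=trainer.root_gpu` for an HPC restore. A minimal sketch of the call sites before and after the change (illustrative only; the surrounding code is not part of this diff):

# Before: the normal path passed a bool, the HPC path an Optional[int]
self.restore_states(model, checkpoint_path, with_gpu=self.trainer.on_gpu)    # normal restore
self.restore_states(model, checkpoint_path, with_gpu=self.trainer.root_gpu)  # HPC restore

# After: both paths pass the same bool flag
self.restore_states(model, checkpoint_path, on_gpu=self.trainer.on_gpu)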

File tree: 1 file changed (+5 -6 lines)

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 5 additions & 6 deletions
@@ -114,23 +114,22 @@ def restore_states(
         self,
         model: LightningModule,
         checkpoint_path: str,
-        with_gpu: Optional[Union[int, bool]],
+        on_gpu: bool,
     ) -> Dict[str, Any]:
         """
         Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
         All restored states are listed in return value description of `dump_checkpoint`.
-        `with_gpu=trainer.on_gpu` works as normal restore, `with_gpu=trainer.root_gpu` works as hpc restore.

         Args:
-            with_gpu: bool for `on_gpu`, Optional[int] for `trainer.root_gpu`.
+            on_gpu: Whether trainer is on GPU or not.
         """
         # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
         checkpoint: Dict[str, Any] = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

         # restore states
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_load_checkpoint(checkpoint)
-        self.restore_model_state(checkpoint, model, with_gpu)
+        self.restore_model_state(checkpoint, model, on_gpu)
         self.restore_training_state(checkpoint)

         return checkpoint
@@ -139,7 +138,7 @@ def restore_model_state(
         self,
         checkpoint: Dict[str, Any],
         model: LightningModule,
-        with_gpu: Union[bool, Optional[int]]
+        on_gpu: bool,
     ) -> None:
         """
         Restore model state.
@@ -151,7 +150,7 @@ def restore_model_state(
         model.load_state_dict(checkpoint['state_dict'])

         # moves the model to the GPU
-        if (with_gpu is True) or ((not isinstance(with_gpu, bool)) and (with_gpu is not None)):
+        if (on_gpu is True) or ((not isinstance(on_gpu, bool)) and (on_gpu is not None)):
             model.cuda(self.trainer.root_gpu)

     def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
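A note on the guard in the last hunk: under the old `Union[bool, Optional[int]]` typing, a plain truthiness test would have been wrong, because a valid GPU index of `0` is falsy in Python. A standalone sketch of the old check's behavior (the `should_move_to_gpu` helper is hypothetical, for illustration only):

from typing import Optional, Union

def should_move_to_gpu(with_gpu: Union[bool, Optional[int]]) -> bool:
    # Accepts with_gpu=True (normal restore) and any int GPU index
    # (HPC restore), including index 0, which `if with_gpu:` would miss.
    return (with_gpu is True) or ((not isinstance(with_gpu, bool)) and (with_gpu is not None))

assert should_move_to_gpu(True) is True    # normal restore on GPU
assert should_move_to_gpu(False) is False  # normal restore on CPU
assert should_move_to_gpu(0) is True       # HPC restore with root_gpu == 0
assert should_move_to_gpu(None) is False   # HPC restore without a GPU

Now that the parameter is annotated `on_gpu: bool`, the same expression evaluates to `on_gpu` itself, so the guard behaves as a simple `if on_gpu:`.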
