@@ -114,23 +114,22 @@ def restore_states(
114
114
self ,
115
115
model : LightningModule ,
116
116
checkpoint_path : str ,
117
- with_gpu : Optional [ Union [ int , bool ]] ,
117
+ on_gpu : bool ,
118
118
) -> Dict [str , Any ]:
119
119
"""
120
120
Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
121
121
All restored states are listed in return value description of `dump_checkpoint`.
122
- `with_gpu=trainer.on_gpu` works as normal restore, `with_gpu=trainer.root_gpu` works as hpc restore.
123
122
124
123
Args:
125
- with_gpu: bool for `on_gpu`, Optional[int] for `trainer.root_gpu` .
124
+ on_gpu: Whether trainer is on GPU or not .
126
125
"""
127
126
# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
128
127
checkpoint : Dict [str , Any ] = pl_load (checkpoint_path , map_location = lambda storage , loc : storage )
129
128
130
129
# restore states
131
130
if self .trainer .datamodule is not None :
132
131
self .trainer .datamodule .on_load_checkpoint (checkpoint )
133
- self .restore_model_state (checkpoint , model , with_gpu )
132
+ self .restore_model_state (checkpoint , model , on_gpu )
134
133
self .restore_training_state (checkpoint )
135
134
136
135
return checkpoint
@@ -139,7 +138,7 @@ def restore_model_state(
139
138
self ,
140
139
checkpoint : Dict [str , Any ],
141
140
model : LightningModule ,
142
- with_gpu : Union [ bool , Optional [ int ]]
141
+ on_gpu : bool ,
143
142
) -> None :
144
143
"""
145
144
Restore model state.
@@ -151,7 +150,7 @@ def restore_model_state(
151
150
model .load_state_dict (checkpoint ['state_dict' ])
152
151
153
152
# moves the model to the GPU
154
- if (with_gpu is True ) or ((not isinstance (with_gpu , bool )) and (with_gpu is not None )):
153
+ if (on_gpu is True ) or ((not isinstance (on_gpu , bool )) and (on_gpu is not None )):
155
154
model .cuda (self .trainer .root_gpu )
156
155
157
156
def restore_training_state (self , checkpoint : Dict [str , Any ]) -> None :
0 commit comments