@@ -63,7 +63,7 @@ def __init__(

     def scale_batch_size(self, trainer, pl_module):
         if trainer.fast_dev_run:
-            rank_zero_warn("Skiping batch size scaler since `fast_dev_run` is enabled.")
+            rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
             return

         if not lightning_hasattr(pl_module, self.batch_arg_name):
@@ -209,7 +209,6 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
     def _try_loop_run(self, trainer):
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.fit_loop.global_step = self._dumped_params["global_step"]
-            # trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
             loop = trainer.fit_loop
         elif trainer.state.fn == TrainerFn.VALIDATING:
             loop = trainer.validate_loop
@@ -235,7 +234,6 @@ def _reset_dataloaders(trainer, pl_module):

     def _dump_params(self, trainer):
         self._dumped_params = {
-            # "current_epoch": trainer.current_epoch,
             "logger": trainer.logger,
             "callbacks": trainer.callbacks,
         }
@@ -280,7 +278,6 @@ def _restore_params(self, trainer):
         trainer.callbacks = self._dumped_params["callbacks"]

         if trainer.state.fn == TrainerFn.FITTING:
-            # trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
             trainer.fit_loop.global_step = self._dumped_params["global_step"]
             loop = trainer.fit_loop
             loop.max_steps = self._dumped_params["max_steps"]
@@ -300,6 +297,9 @@ def _restore_params(self, trainer):
             loop.verbose = self._dumped_params["loop_verbose"]

     def pre_early_exit(self, trainer):
+        if trainer.fast_dev_run:
+            return
+
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.should_stop = True
             self._dumped_params["num_training_batches"] = trainer.num_training_batches
@@ -315,6 +315,9 @@ def pre_early_exit(self, trainer):
             trainer.num_predict_batches = [0]

     def post_early_exit(self, trainer):
+        if trainer.fast_dev_run:
+            return
+
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.num_training_batches = self._dumped_params["num_training_batches"]
             loop = trainer.fit_loop
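
With the guards above, the batch size finder becomes a no-op when `fast_dev_run` is enabled: `scale_batch_size` warns and returns early, and the pre/post early-exit hooks no longer touch the trainer state. The snippet below is a minimal sketch of that scenario, assuming the public `BatchSizeFinder` callback from `pytorch_lightning.callbacks` (available in PL >= 1.8); the `RandomDataset` and `DemoModel` names are hypothetical stand-ins written only for this illustration, not part of the change.

# Minimal sketch: BatchSizeFinder combined with fast_dev_run.
# Assumption: pytorch_lightning >= 1.8 exposes the BatchSizeFinder callback.
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BatchSizeFinder


class RandomDataset(Dataset):
    # Tiny synthetic dataset so the example is self-contained.
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32), torch.randn(2)


class DemoModel(pl.LightningModule):
    def __init__(self, batch_size: int = 2):
        super().__init__()
        # Attribute the finder would scale via its default `batch_arg_name="batch_size"`.
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=self.batch_size)


# With fast_dev_run=True the finder now warns ("Skipping batch size scaler ...")
# and returns early, so the single dev batch runs with the user-set batch_size.
trainer = pl.Trainer(fast_dev_run=True, callbacks=[BatchSizeFinder()])
trainer.fit(DemoModel())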