29
29
from pytorch_lightning .callbacks .base import Callback
30
30
from pytorch_lightning .loggers .base import DummyLogger
31
31
from pytorch_lightning .trainer .states import TrainerFn
32
+ from pytorch_lightning .tuner .tuning import _TunerExitException
32
33
from pytorch_lightning .utilities .cloud_io import get_filesystem
33
34
from pytorch_lightning .utilities .data import has_len_all_ranks
34
35
from pytorch_lightning .utilities .distributed import rank_zero_info
@@ -46,7 +47,6 @@ def __init__(
46
47
init_val = 2 ,
47
48
max_trials = 25 ,
48
49
batch_arg_name = "batch_size" ,
49
- early_exit = False ,
50
50
):
51
51
"""Callback try to find the largest batch size for a given model that does not give an out of memory (OOM)
52
52
error. It works with both training and evalation. All you need to do is add it as a callback inside Trainer
@@ -56,7 +56,7 @@ def __init__(
56
56
Args:
57
57
mode: search strategy to update the batch size:
58
58
59
- - ``'power'`` (default) : Keep multiplying the batch size by 2, until we get an OOM error.
59
+ - ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
60
60
- ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
61
61
do a binary search between the last successful batch size and the batch size that failed.
62
62
@@ -76,9 +76,6 @@ def __init__(
76
76
- ``model``
77
77
- ``model.hparams``
78
78
- ``trainer.datamodule`` (the datamodule passed to the tune method)
79
-
80
- early_exit: whether to continue with the training/evaluation or stop after
81
- an optimal batch size has been found.
82
79
"""
83
80
supported_modes = ("power" , "binsearch" )
84
81
mode = mode .lower ()
@@ -91,7 +88,8 @@ def __init__(
91
88
self .max_trials = max_trials
92
89
self .batch_arg_name = batch_arg_name
93
90
self .optimal_batch_size = init_val
94
- self .early_exit = early_exit
91
+
92
+ self ._early_exit = False
95
93
96
94
def scale_batch_size (self , trainer , pl_module ):
97
95
if trainer .fast_dev_run :
@@ -165,6 +163,9 @@ def scale_batch_size(self, trainer, pl_module):
165
163
print (f"new batch size: { new_size } " )
166
164
self .optimal_batch_size = new_size
167
165
166
+ if self ._early_exit :
167
+ raise _TunerExitException ()
168
+
168
169
def _run_power_scaling (self , trainer , pl_module , new_size ):
169
170
"""Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
170
171
for _ in range (self .max_trials ):
@@ -332,99 +333,21 @@ def _restore_params(self, trainer):
332
333
if "loop_verbose" in self ._dumped_params :
333
334
loop .verbose = self ._dumped_params ["loop_verbose" ]
334
335
335
- def pre_early_exit (self , trainer ):
336
- if trainer .fast_dev_run :
337
- return
338
-
339
- # this is required to stop the respective loops
340
- if trainer .state .fn == TrainerFn .FITTING :
341
- self ._dumped_params ["num_training_batches" ] = trainer .num_training_batches
342
- trainer .num_training_batches = 0
343
- elif trainer .state .fn == TrainerFn .VALIDATING :
344
- self ._dumped_params ["num_val_batches" ] = trainer .num_val_batches
345
- trainer .num_val_batches = [0 ]
346
- elif trainer .state .fn == TrainerFn .TESTING :
347
- self ._dumped_params ["num_test_batches" ] = trainer .num_test_batches
348
- trainer .num_test_batches = [0 ]
349
- elif trainer .state .fn == TrainerFn .PREDICTING :
350
- self ._dumped_params ["num_predict_batches" ] = trainer .num_predict_batches
351
- trainer .num_predict_batches = [0 ]
352
-
353
- def post_early_exit (self , trainer ):
354
- if trainer .fast_dev_run :
355
- return
356
-
357
- # restore the state used to stop the respective loop
358
- if trainer .state .fn == TrainerFn .FITTING :
359
- trainer .num_training_batches = self ._dumped_params ["num_training_batches" ]
360
- loop = trainer .fit_loop
361
- if trainer .state .fn == TrainerFn .VALIDATING :
362
- trainer .num_val_batches = self ._dumped_params ["num_val_batches" ]
363
- loop = trainer .validate_loop
364
- if trainer .state .fn == TrainerFn .TESTING :
365
- trainer .num_test_batches = self ._dumped_params ["num_test_batches" ]
366
- loop = trainer .test_loop
367
- if trainer .state .fn == TrainerFn .PREDICTING :
368
- trainer .num_predict_batches = self ._dumped_params ["num_predict_batches" ]
369
- loop = trainer .predict_loop
370
-
371
- loop .load_state_dict (self ._dumped_params ["loop_state_dict" ], force_load_progress = True )
372
- trainer .callbacks = [cb for cb in trainer .callbacks if not isinstance (cb , BatchSizeFinder )]
373
-
374
336
def on_fit_start (self , trainer , pl_module ):
375
337
self .scale_batch_size (trainer , pl_module )
376
338
377
- if self .early_exit :
378
- self .pre_early_exit (trainer )
379
- else :
380
- trainer .callbacks = [cb for cb in trainer .callbacks if not isinstance (cb , BatchSizeFinder )]
381
-
382
339
def on_validation_start (self , trainer , pl_module ):
383
340
if trainer .sanity_checking or trainer .state .fn != TrainerFn .VALIDATING :
384
341
return
385
342
386
343
self .scale_batch_size (trainer , pl_module )
387
344
388
- if self .early_exit :
389
- self .pre_early_exit (trainer )
390
- else :
391
- trainer .callbacks = [cb for cb in trainer .callbacks if not isinstance (cb , BatchSizeFinder )]
392
-
393
345
def on_test_start (self , trainer , pl_module ):
394
346
self .scale_batch_size (trainer , pl_module )
395
347
396
- if self .early_exit :
397
- self .pre_early_exit (trainer )
398
- else :
399
- trainer .callbacks = [cb for cb in trainer .callbacks if not isinstance (cb , BatchSizeFinder )]
400
-
401
348
def on_predict_start (self , trainer , pl_module ):
402
349
self .scale_batch_size (trainer , pl_module )
403
350
404
- if self .early_exit :
405
- self .pre_early_exit (trainer )
406
- else :
407
- trainer .callbacks = [cb for cb in trainer .callbacks if not isinstance (cb , BatchSizeFinder )]
408
-
409
- def on_fit_end (self , trainer , pl_module ):
410
- if self .early_exit :
411
- self .post_early_exit (trainer )
412
-
413
- def on_validation_end (self , trainer , pl_module ):
414
- if trainer .sanity_checking or trainer .state .fn != TrainerFn .VALIDATING :
415
- return
416
-
417
- if self .early_exit :
418
- self .post_early_exit (trainer )
419
-
420
- def on_test_end (self , trainer , pl_module ):
421
- if self .early_exit :
422
- self .post_early_exit (trainer )
423
-
424
- def on_predict_end (self , trainer , pl_module ):
425
- if self .early_exit :
426
- self .post_early_exit (trainer )
427
-
428
351
def _adjust_batch_size (
429
352
self ,
430
353
trainer : "pl.Trainer" ,
0 commit comments