@@ -48,10 +48,42 @@ def __init__(
         batch_arg_name="batch_size",
         early_exit=False,
     ):
+        """Callback tries to find the largest batch size for a given model that does not give an out of memory (OOM)
+        error. It works with both training and evaluation. All you need to do is add it as a callback inside Trainer
+        and call ``trainer.fit/validate/test/predict()``. Internally it calls the respective step function
+        ``steps_per_trial`` times for each batch size until one of the batch sizes generates an OOM error.
 
+        Args:
+            mode: search strategy to update the batch size:
+
+                - ``'power'`` (default): Keep multiplying the batch size by 2, until we get an OOM error.
+                - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
+                  do a binary search between the last successful batch size and the batch size that failed.
+
+            steps_per_trial: number of steps to run with a given batch size.
+                Ideally 1 should be enough to test if an OOM error occurs,
+                however in practice a few are needed.
+
+            init_val: initial batch size to start the search with.
+
+            max_trials: max number of increases in batch size done before
+                the algorithm is terminated.
+
+            batch_arg_name: name of the attribute that stores the batch size.
+                It is expected that the user has provided a model or datamodule that has a hyperparameter
+                with that name. We will look for this attribute name in the following places:
+
+                - ``model``
+                - ``model.hparams``
+                - ``trainer.datamodule`` (the datamodule passed to the tune method)
+
+            early_exit: whether to continue with the training/evaluation or stop after
+                an optimal batch size has been found.
+        """
+        supported_modes = ("power", "binsearch")
         mode = mode.lower()
-        if mode not in ("power", "binsearch"):
-            raise MisconfigurationException("`mode` should be either 'power' or 'binsearch'")
+        if mode not in supported_modes:
+            raise MisconfigurationException(f"`mode` should be one of {supported_modes}")
 
         self.mode = mode
         self.steps_per_trial = steps_per_trial
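For context, a minimal usage sketch of the callback added in this diff. It assumes the class is exported as ``pytorch_lightning.callbacks.BatchSizeFinder`` and that the model exposes a ``batch_size`` hyperparameter (the default ``batch_arg_name``); the toy model and data below are illustrative only and not part of the change.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BatchSizeFinder  # import path assumed


class ToyModel(pl.LightningModule):
    def __init__(self, batch_size=2):
        super().__init__()
        self.save_hyperparameters()  # exposes self.hparams.batch_size for the finder
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
        return DataLoader(dataset, batch_size=self.hparams.batch_size)


# add the callback and call fit(); the batch size is scaled before training continues
trainer = pl.Trainer(max_epochs=1, callbacks=[BatchSizeFinder(mode="power", steps_per_trial=3)])
trainer.fit(ToyModel())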
@@ -121,6 +153,10 @@ def scale_batch_size(self, trainer, pl_module):
         if fs.exists(save_path):
             fs.rm(save_path)
 
+        # global step and current epoch are incremented before being saved in the checkpoint
+        trainer.fit_loop.global_step -= 1
+        trainer.fit_loop.current_epoch -= 1
+
         self._restore_params(trainer)
 
         if trainer.progress_bar_callback:
@@ -165,7 +201,7 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
         while True:
             garbage_collection_cuda()
             try:
-                # Try fit
+                # run loop
                 self._try_loop_run(trainer)
                 count += 1
                 if count > self.max_trials:
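To make the two search strategies from the docstring concrete, here is a hedged, standalone sketch (not the library code): ``fits_in_memory`` stands in for running ``steps_per_trial`` steps and catching an OOM error, and the two phases mirror the ``'power'`` and ``'binsearch'`` modes described above.

def find_batch_size(fits_in_memory, init_val=2, max_trials=25):
    """Toy illustration of the 'binsearch' strategy: double until failure, then bisect."""
    size, last_good = init_val, None
    for _ in range(max_trials):          # phase 1: power scaling
        if fits_in_memory(size):
            last_good, size = size, size * 2
        else:
            break
    else:
        return last_good                 # never failed within max_trials
    low, high = last_good, size          # phase 2: bisect between last success and first failure
    while low is not None and high - low > 1:
        mid = (low + high) // 2
        low, high = (mid, high) if fits_in_memory(mid) else (low, mid)
    return low


print(find_batch_size(lambda s: s <= 100))  # -> 100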
@@ -217,7 +253,7 @@ def _try_loop_run(self, trainer):
         elif trainer.state.fn == TrainerFn.PREDICTING:
             loop = trainer.predict_loop
 
-        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
+        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]), force_load_progress=True)
         loop.run()
 
     @staticmethod
@@ -292,16 +328,16 @@ def _restore_params(self, trainer):
             loop = trainer.predict_loop
             trainer.limit_predict_batches = self._dumped_params["limit_predict_batches"]
 
-        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
+        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]), force_load_progress=True)
         if "loop_verbose" in self._dumped_params:
             loop.verbose = self._dumped_params["loop_verbose"]
 
     def pre_early_exit(self, trainer):
         if trainer.fast_dev_run:
             return
 
+        # this is required to stop the respective loops
         if trainer.state.fn == TrainerFn.FITTING:
-            trainer.should_stop = True
             self._dumped_params["num_training_batches"] = trainer.num_training_batches
             trainer.num_training_batches = 0
         elif trainer.state.fn == TrainerFn.VALIDATING:
@@ -318,6 +354,7 @@ def post_early_exit(self, trainer):
         if trainer.fast_dev_run:
             return
 
+        # restore the state used to stop the respective loop
         if trainer.state.fn == TrainerFn.FITTING:
            trainer.num_training_batches = self._dumped_params["num_training_batches"]
            loop = trainer.fit_loop
@@ -331,7 +368,7 @@ def post_early_exit(self, trainer):
             trainer.num_predict_batches = self._dumped_params["num_predict_batches"]
             loop = trainer.predict_loop
 
-        loop.load_state_dict(self._dumped_params["loop_state_dict"])
+        loop.load_state_dict(self._dumped_params["loop_state_dict"], force_load_progress=True)
         trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
 
     def on_fit_start(self, trainer, pl_module):
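As a usage note for the ``pre_early_exit``/``post_early_exit`` pair above: with ``early_exit=True`` the run is expected to stop right after the batch size search instead of continuing with training or evaluation. A hedged sketch, with the import path assumed as in the earlier example:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BatchSizeFinder  # import path assumed

# with early_exit=True the fit/validate/test/predict call is expected to stop
# right after the batch size search (see pre_early_exit/post_early_exit above)
trainer = Trainer(callbacks=[BatchSizeFinder(early_exit=True)])
# trainer.fit(model)  # only runs the search trials, then returns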
@@ -346,6 +383,8 @@ def on_validation_start(self, trainer, pl_module):
         if trainer.sanity_checking or trainer.state.fn != TrainerFn.VALIDATING:
             return
 
+        self.scale_batch_size(trainer, pl_module)
+
         if self.early_exit:
             self.pre_early_exit(trainer)
         else: