@@ -20,6 +20,7 @@
 import os
 import uuid
+from copy import deepcopy
 from typing import Optional, Tuple
 
 from torch.utils.data.dataloader import DataLoader
@@ -38,7 +39,15 @@
 
 
 class BatchSizeFinder(Callback):
-    def __init__(self, mode: str = "power", steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name="batch_size"):
+    def __init__(
+        self,
+        mode: str = "power",
+        steps_per_trial=3,
+        init_val=2,
+        max_trials=25,
+        batch_arg_name="batch_size",
+        early_exit=False,
+    ):
 
         mode = mode.lower()
         if mode not in ("power", "binsearch"):
@@ -50,6 +59,7 @@ def __init__(self, mode: str = "power", steps_per_trial=3, init_val=2, max_trial
         self.max_trials = max_trials
         self.batch_arg_name = batch_arg_name
         self.optimal_batch_size = init_val
+        self.early_exit = early_exit
 
     def scale_batch_size(self, trainer, pl_module):
         if trainer.fast_dev_run:
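For orientation, this is how the callback would be attached once merged; a minimal usage sketch assuming the standard `Trainer(callbacks=...)` API, where `MyModel` and `MyDataModule` are hypothetical placeholders that expose a `batch_size` attribute (the target of `batch_arg_name`):

import pytorch_lightning as pl

# Hypothetical usage sketch: tune the batch size by binary search, then
# stop the run early instead of continuing training (early_exit=True).
finder = BatchSizeFinder(mode="binsearch", init_val=2, max_trials=25, early_exit=True)
trainer = pl.Trainer(max_epochs=10, callbacks=[finder])
trainer.fit(MyModel(), datamodule=MyDataModule())
print(finder.optimal_batch_size)  # last batch size that survived the trials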
@@ -90,7 +100,7 @@ def scale_batch_size(self, trainer, pl_module):
         self._reset_params(trainer)
 
         # Save initial model, that is loaded after batch size is found
-        save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
+        save_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
         trainer.save_checkpoint(save_path)
 
         if trainer.progress_bar_callback:
@@ -112,6 +122,7 @@ def scale_batch_size(self, trainer, pl_module):
         fs.rm(save_path)
 
         self._restore_params(trainer)
+
         if trainer.progress_bar_callback:
             trainer.progress_bar_callback.enable()
 
@@ -182,7 +193,12 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
                 garbage_collection_cuda()
                 high = new_size
                 midval = (high + low) // 2
-                new_size, _ = self._adjust_batch_size(trainer, value=midval, desc="failed")
+                new_size, changed = self._adjust_batch_size(trainer, value=midval, desc="failed")
+
+                if changed:
+                    # Force the dataloaders to reset as the batch size has changed
+                    self._reset_dataloaders(trainer, pl_module)
+
                 if high - low <= 1:
                     break
             else:
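The hunk above is the out-of-memory branch of the binary search: the failing size becomes the new upper bound, the next trial runs at the midpoint, and the search stops once the bounds are within one of each other. Detached from the Trainer machinery, the narrowing logic looks like this; `run_trial` is a hypothetical predicate standing in for `_try_loop_run`, returning False when a size triggers an OOM:

def binsearch_batch_size(run_trial, low, high):
    # Assumes `low` is known to fit and `high` is known to fail;
    # returns the largest size strictly between them that fits.
    while high - low > 1:
        mid = (high + low) // 2
        if run_trial(mid):
            low = mid   # mid fits: move the lower bound up
        else:
            high = mid  # mid fails: move the upper bound down
    return low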
@@ -193,14 +209,17 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
     def _try_loop_run(self, trainer):
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.fit_loop.global_step = self._dumped_params["global_step"]
-            trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
-            trainer.fit_loop.run()
+            # trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
+            loop = trainer.fit_loop
         elif trainer.state.fn == TrainerFn.VALIDATING:
-            trainer.validate_loop.run()
+            loop = trainer.validate_loop
         elif trainer.state.fn == TrainerFn.TESTING:
-            trainer.test_loop.run()
+            loop = trainer.test_loop
         elif trainer.state.fn == TrainerFn.PREDICTING:
-            trainer.predict_loop.run()
+            loop = trainer.predict_loop
+
+        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
+        loop.run()
 
     @staticmethod
     def _reset_dataloaders(trainer, pl_module):
@@ -216,57 +235,153 @@ def _reset_dataloaders(trainer, pl_module):
 
     def _dump_params(self, trainer):
         self._dumped_params = {
-            "current_epoch": trainer.current_epoch,
-            "global_step": trainer.global_step,
-            "max_steps": trainer.max_steps,
+            # "current_epoch": trainer.current_epoch,
             "logger": trainer.logger,
             "callbacks": trainer.callbacks,
-            "limit_train_batches": trainer.limit_train_batches,
-            "limit_val_batches": trainer.limit_val_batches,
-            "limit_test_batches": trainer.limit_test_batches,
-            "limit_predict_batches": trainer.limit_predict_batches,
         }
 
+        if trainer.state.fn == TrainerFn.FITTING:
+            loop = trainer.fit_loop
+            self._dumped_params["global_step"] = trainer.global_step
+            self._dumped_params["max_steps"] = trainer.max_steps
+            self._dumped_params["limit_val_batches"] = trainer.limit_val_batches
+        elif trainer.state.fn == TrainerFn.VALIDATING:
+            loop = trainer.validate_loop
+            self._dumped_params["limit_val_batches"] = trainer.limit_val_batches
+        elif trainer.state.fn == TrainerFn.TESTING:
+            loop = trainer.test_loop
+            self._dumped_params["limit_test_batches"] = trainer.limit_test_batches
+        elif trainer.state.fn == TrainerFn.PREDICTING:
+            loop = trainer.predict_loop
+            self._dumped_params["limit_predict_batches"] = trainer.limit_predict_batches
+
+        self._dumped_params["loop_state_dict"] = deepcopy(loop.state_dict())
+        if hasattr(loop, "verbose"):
+            self._dumped_params["loop_verbose"] = loop.verbose
+
     def _reset_params(self, trainer):
         trainer.logger = DummyLogger() if trainer.logger is not None else None
         trainer.callbacks = []
+
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.limit_val_batches = self.steps_per_trial
             trainer.fit_loop.max_steps = self.steps_per_trial
         elif trainer.state.fn == TrainerFn.VALIDATING:
             trainer.limit_val_batches = self.steps_per_trial
+            trainer.validate_loop.verbose = False
         elif trainer.state.fn == TrainerFn.TESTING:
             trainer.limit_test_batches = self.steps_per_trial
+            trainer.test_loop.verbose = False
         elif trainer.state.fn == TrainerFn.PREDICTING:
             trainer.limit_predict_batches = self.steps_per_trial
 
     def _restore_params(self, trainer):
-        trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
-        trainer.fit_loop.global_step = self._dumped_params["global_step"]
-        trainer.fit_loop.max_steps = self._dumped_params["max_steps"]
         trainer.logger = self._dumped_params["logger"]
         trainer.callbacks = self._dumped_params["callbacks"]
-        trainer.limit_train_batches = self._dumped_params["limit_train_batches"]
-        trainer.limit_val_batches = self._dumped_params["limit_val_batches"]
-        trainer.limit_test_batches = self._dumped_params["limit_test_batches"]
-        trainer.limit_predict_batches = self._dumped_params["limit_predict_batches"]
 
-    def on_train_epoch_start(self, trainer, pl_module):
-        self.scale_batch_size(trainer, pl_module)
+        if trainer.state.fn == TrainerFn.FITTING:
+            # trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
+            trainer.fit_loop.global_step = self._dumped_params["global_step"]
+            loop = trainer.fit_loop
+            loop.max_steps = self._dumped_params["max_steps"]
+            trainer.limit_val_batches = self._dumped_params["limit_val_batches"]
+        elif trainer.state.fn == TrainerFn.VALIDATING:
+            loop = trainer.validate_loop
+            trainer.limit_val_batches = self._dumped_params["limit_val_batches"]
+        elif trainer.state.fn == TrainerFn.TESTING:
+            loop = trainer.test_loop
+            trainer.limit_test_batches = self._dumped_params["limit_test_batches"]
+        elif trainer.state.fn == TrainerFn.PREDICTING:
+            loop = trainer.predict_loop
+            trainer.limit_predict_batches = self._dumped_params["limit_predict_batches"]
+
+        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
+        if "loop_verbose" in self._dumped_params:
+            loop.verbose = self._dumped_params["loop_verbose"]
+
+    def pre_early_exit(self, trainer):
+        if trainer.state.fn == TrainerFn.FITTING:
+            trainer.should_stop = True
+            self._dumped_params["num_training_batches"] = trainer.num_training_batches
+            trainer.num_training_batches = 0
+        elif trainer.state.fn == TrainerFn.VALIDATING:
+            self._dumped_params["num_val_batches"] = trainer.num_val_batches
+            trainer.num_val_batches = [0]
+        elif trainer.state.fn == TrainerFn.TESTING:
+            self._dumped_params["num_test_batches"] = trainer.num_test_batches
+            trainer.num_test_batches = [0]
+        elif trainer.state.fn == TrainerFn.PREDICTING:
+            self._dumped_params["num_predict_batches"] = trainer.num_predict_batches
+            trainer.num_predict_batches = [0]
+
+    def post_early_exit(self, trainer):
+        if trainer.state.fn == TrainerFn.FITTING:
+            trainer.num_training_batches = self._dumped_params["num_training_batches"]
+            loop = trainer.fit_loop
+        if trainer.state.fn == TrainerFn.VALIDATING:
+            trainer.num_val_batches = self._dumped_params["num_val_batches"]
+            loop = trainer.validate_loop
+        if trainer.state.fn == TrainerFn.TESTING:
+            trainer.num_test_batches = self._dumped_params["num_test_batches"]
+            loop = trainer.test_loop
+        if trainer.state.fn == TrainerFn.PREDICTING:
+            trainer.num_predict_batches = self._dumped_params["num_predict_batches"]
+            loop = trainer.predict_loop
+
+        loop.load_state_dict(self._dumped_params["loop_state_dict"])
         trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
 
-    def on_validation_epoch_start(self, trainer, pl_module):
-        if not trainer.sanity_checking:
-            self.scale_batch_size(trainer, pl_module)
+    def on_fit_start(self, trainer, pl_module):
+        self.scale_batch_size(trainer, pl_module)
+
+        if self.early_exit:
+            self.pre_early_exit(trainer)
+        else:
+            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
+
+    def on_validation_start(self, trainer, pl_module):
+        if trainer.sanity_checking or trainer.state.fn != TrainerFn.VALIDATING:
+            return
+
+        if self.early_exit:
+            self.pre_early_exit(trainer)
+        else:
             trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
 
-    def on_test_epoch_start(self, trainer, pl_module):
+    def on_test_start(self, trainer, pl_module):
         self.scale_batch_size(trainer, pl_module)
-        trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
 
-    def on_predict_epoch_start(self, trainer, pl_module):
+        if self.early_exit:
+            self.pre_early_exit(trainer)
+        else:
+            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
+
+    def on_predict_start(self, trainer, pl_module):
         self.scale_batch_size(trainer, pl_module)
-        trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
+
+        if self.early_exit:
+            self.pre_early_exit(trainer)
+        else:
+            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
+
+    def on_fit_end(self, trainer, pl_module):
+        if self.early_exit:
+            self.post_early_exit(trainer)
+
+    def on_validation_end(self, trainer, pl_module):
+        if trainer.sanity_checking or trainer.state.fn != TrainerFn.VALIDATING:
+            return
+
+        if self.early_exit:
+            self.post_early_exit(trainer)
+
+    def on_test_end(self, trainer, pl_module):
+        if self.early_exit:
+            self.post_early_exit(trainer)
+
+    def on_predict_end(self, trainer, pl_module):
+        if self.early_exit:
+            self.post_early_exit(trainer)
 
     def _adjust_batch_size(
         self,
@@ -295,19 +410,19 @@ def _adjust_batch_size(
         if desc:
             rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
 
-        # TODO improve this for CombinedLoader
+        # TODO improve this for CombinedLoader and multi dataloaders
         if trainer.state.fn == TrainerFn.FITTING:
             if not self._is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
                 new_size = min(new_size, len(trainer.train_dataloader.dataset))
         if trainer.state.fn == TrainerFn.VALIDATING:
             if not self._is_valid_batch_size(new_size, trainer.val_dataloaders, trainer):
-                new_size = min(new_size, len(trainer.val_dataloaders.dataset))
+                new_size = min(new_size, len(trainer.val_dataloaders[0].dataset))
         if trainer.state.fn == TrainerFn.TESTING:
             if not self._is_valid_batch_size(new_size, trainer.test_dataloaders, trainer):
-                new_size = min(new_size, len(trainer.test_dataloaders.dataset))
+                new_size = min(new_size, len(trainer.test_dataloaders[0].dataset))
         if trainer.state.fn == TrainerFn.PREDICTING:
             if not self._is_valid_batch_size(new_size, trainer.predict_dataloaders, trainer):
-                new_size = min(new_size, len(trainer.predict_dataloaders.dataset))
+                new_size = min(new_size, len(trainer.predict_dataloaders[0].dataset))
 
         changed = new_size != batch_size
         lightning_setattr(model, self.batch_arg_name, new_size)
@@ -316,4 +431,4 @@ def _adjust_batch_size(
     @staticmethod
     def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
         module = trainer.lightning_module or trainer.datamodule
-        return not has_len_all_ranks(dataloader, trainer.training_type_plugin, module) or batch_size <= len(dataloader)
+        return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
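The "power" mode accepted by `__init__` is not shown in these hunks; in Lightning's tuner it conventionally doubles the batch size from `init_val` until a trial runs out of memory or `max_trials` is exhausted. A simplified sketch under that assumption, reusing the hypothetical `run_trial` predicate from above:

def power_scale_batch_size(run_trial, init_val=2, max_trials=25):
    # Keep doubling while trials succeed; return the last size that fit.
    size, last_ok = init_val, None
    for _ in range(max_trials):
        if not run_trial(size):
            break       # out of memory: stop growing
        last_ok = size
        size *= 2       # success: double and probe again
    return last_ok if last_ok is not None else init_val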