 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.loggers.base import DummyLogger
-from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.data import has_len_all_ranks
 from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
 from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
@@ -42,11 +39,11 @@ class BatchSizeFinder(Callback):
     def __init__(
         self,
         mode: str = "power",
-        steps_per_trial=3,
-        init_val=2,
-        max_trials=25,
-        batch_arg_name="batch_size",
-    ):
+        steps_per_trial: int = 3,
+        init_val: int = 2,
+        max_trials: int = 25,
+        batch_arg_name: str = "batch_size",
+    ) -> None:
         """Callback that tries to find the largest batch size for a given model that does not give an out of memory (OOM)
         error. It works with both training and evaluation. All you need to do is add it as a callback inside Trainer
         and call ``trainer.fit/validate/test/predict()``. Internally it calls the respective step function
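As context for the docstring above, here is a minimal usage sketch. It is not part of this diff, and the import path and module name are assumptions; adjust them to whatever the installed pytorch_lightning version actually exposes.

```python
# Minimal sketch (assumed import path): attach the finder and let it scale `batch_size`.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BatchSizeFinder  # assumed export location

model = MyLightningModule()  # hypothetical module exposing a `batch_size` attribute
trainer = pl.Trainer(callbacks=[BatchSizeFinder(mode="power", max_trials=25)])
trainer.fit(model)  # the callback adjusts `model.batch_size` before the real run continues
```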
@@ -90,7 +87,7 @@ def __init__(

         self._early_exit = False

-    def scale_batch_size(self, trainer, pl_module):
+    def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if trainer.fast_dev_run:
             rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
             return
@@ -165,7 +162,7 @@ def scale_batch_size(self, trainer, pl_module):
         if self._early_exit:
             raise _TunerExitException()

-    def _run_power_scaling(self, trainer, pl_module, new_size):
+    def _run_power_scaling(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int) -> int:
         """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
         for _ in range(self.max_trials):
             garbage_collection_cuda()
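The "power" mode described in the docstring above keeps doubling the batch size until an OOM error or `max_trials` is hit. A rough illustrative sketch of that idea, not the code from this PR:

```python
# Illustrative only: doubling until an out-of-memory error or the trial budget runs out.
def power_scaling_sketch(probe, init_val: int = 2, max_trials: int = 25) -> int:
    size = init_val
    for _ in range(max_trials):
        try:
            probe(size)               # hypothetical callable running a few steps at this size
        except RuntimeError:          # stand-in for the CUDA OOM check
            return max(size // 2, 1)  # back off to the last size that fit
        size *= 2                     # no OOM: double and try again
    return size
```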
@@ -189,7 +186,7 @@ def _run_power_scaling(self, trainer, pl_module, new_size):

         return new_size

-    def _run_binary_scaling(self, trainer, pl_module, new_size):
+    def _run_binary_scaling(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int) -> int:
         """Batch scaling mode where the size is initially doubled at each iteration until an OOM error is
         encountered.

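In "binsearch" mode the same doubling phase runs first; after the first OOM, the gap between the last working size and the failing size is narrowed by binary search. A rough sketch of that refinement step, under the same illustrative assumptions as the previous snippet:

```python
# Illustrative only: binary-search refinement between a known-good size and a failing one.
def binsearch_refine_sketch(probe, low: int, high: int, max_trials: int = 25) -> int:
    best = low
    for _ in range(max_trials):
        if high - low <= 1:
            break
        mid = (low + high) // 2
        try:
            probe(mid)
            best = low = mid     # mid fits: raise the lower bound
        except RuntimeError:     # stand-in for the OOM check
            high = mid           # mid is too large: lower the upper bound
    return best
```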
@@ -242,7 +239,9 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):

         return new_size

-    def _try_loop_run(self, trainer):
+    def _try_loop_run(self, trainer: "pl.Trainer") -> None:
+        from pytorch_lightning.trainer.states import TrainerFn
+
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.fit_loop.global_step = self._dumped_params["global_step"]
             loop = trainer.fit_loop
@@ -257,7 +256,9 @@ def _try_loop_run(self, trainer):
         loop.run()

     @staticmethod
-    def _reset_dataloaders(trainer, pl_module):
+    def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        from pytorch_lightning.trainer.states import TrainerFn
+
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.reset_train_dataloader(pl_module)
             trainer.reset_val_dataloader(pl_module)
@@ -268,7 +269,9 @@ def _reset_dataloaders(trainer, pl_module):
         elif trainer.state.fn == TrainerFn.PREDICTING:
             trainer.reset_predict_dataloader(pl_module)

-    def _dump_params(self, trainer):
+    def _dump_params(self, trainer: "pl.Trainer") -> None:
+        from pytorch_lightning.trainer.states import TrainerFn
+
         self._dumped_params = {
             "logger": trainer.logger,
             "callbacks": trainer.callbacks,
@@ -293,7 +296,10 @@ def _dump_params(self, trainer):
         if hasattr(loop, "verbose"):
             self._dumped_params["loop_verbose"] = loop.verbose

-    def _reset_params(self, trainer):
+    def _reset_params(self, trainer: "pl.Trainer") -> None:
+        from pytorch_lightning.loggers.base import DummyLogger
+        from pytorch_lightning.trainer.states import TrainerFn
+
         trainer.logger = DummyLogger() if trainer.logger is not None else None
         trainer.callbacks = []

@@ -309,7 +315,9 @@ def _reset_params(self, trainer):
         elif trainer.state.fn == TrainerFn.PREDICTING:
             trainer.limit_predict_batches = self.steps_per_trial

-    def _restore_params(self, trainer):
+    def _restore_params(self, trainer: "pl.Trainer") -> None:
+        from pytorch_lightning.trainer.states import TrainerFn
+
         trainer.logger = self._dumped_params["logger"]
         trainer.callbacks = self._dumped_params["callbacks"]

@@ -332,19 +340,21 @@ def _restore_params(self, trainer):
         if "loop_verbose" in self._dumped_params:
             loop.verbose = self._dumped_params["loop_verbose"]

-    def on_fit_start(self, trainer, pl_module):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.scale_batch_size(trainer, pl_module)

-    def on_validation_start(self, trainer, pl_module):
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        from pytorch_lightning.trainer.states import TrainerFn
+
         if trainer.sanity_checking or trainer.state.fn != TrainerFn.VALIDATING:
             return

         self.scale_batch_size(trainer, pl_module)

-    def on_test_start(self, trainer, pl_module):
+    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.scale_batch_size(trainer, pl_module)

-    def on_predict_start(self, trainer, pl_module):
+    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.scale_batch_size(trainer, pl_module)

     def _adjust_batch_size(
@@ -368,6 +378,8 @@ def _adjust_batch_size(
             The new batch size for the next trial and a bool that signals whether the
             new value is different than the previous batch size.
         """
+        from pytorch_lightning.trainer.states import TrainerFn
+
         model = trainer.lightning_module
         batch_size = lightning_getattr(model, self.batch_arg_name)
         new_size = value if value is not None else int(batch_size * factor)
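The sizing rule on the last line above is simply: use the explicit `value` when one is passed, otherwise scale the current size by `factor` and truncate to an int. A small worked example with hypothetical numbers:

```python
from typing import Optional

def adjust(batch_size: int, factor: float = 1.0, value: Optional[int] = None) -> int:
    # mirrors the rule in the diff: an explicit value wins, otherwise scale by factor
    return value if value is not None else int(batch_size * factor)

assert adjust(32, factor=2.0) == 64    # doubling step in "power" mode
assert adjust(32, factor=0.5) == 16    # halving step after an OOM
assert adjust(32, value=100) == 100    # explicit value overrides the factor
```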
@@ -393,6 +405,8 @@ def _adjust_batch_size(
         return new_size, changed

     @staticmethod
-    def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
+    def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer") -> bool:
+        from pytorch_lightning.utilities.data import has_len_all_ranks
+
         module = trainer.lightning_module or trainer.datamodule
         return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
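The predicate above treats a candidate size as valid when the dataloader's length cannot be established across all ranks, or when the candidate does not exceed `len(dataloader)`. A tiny illustrative stand-in for that boolean logic, with hypothetical values:

```python
# Illustrative only: hypothetical stand-in for the validity check.
def is_valid(batch_size: int, has_length: bool, length: int = 0) -> bool:
    return not has_length or batch_size <= length

assert is_valid(512, has_length=False)               # no usable length: always accepted
assert is_valid(8, has_length=True, length=10)       # within the reported length
assert not is_valid(16, has_length=True, length=10)  # larger than the reported length: rejected
```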