Finds optimal batch size
"""

- import logging
import os
import uuid
from typing import Optional, Tuple
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import has_len_all_ranks
+ from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.warnings import rank_zero_warn

- log = logging.getLogger(__name__)
-

class BatchSizeFinder(Callback):
    def __init__(self, mode: str = "power", steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name="batch_size"):
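For orientation, here is a minimal sketch of how the callback defined above might be attached to a `Trainer`. The import path and the write-back to `model.batch_size` are assumptions inferred from the signature and the `lightning_setattr` import, not something this diff shows directly.

```python
# Hedged usage sketch: assumes BatchSizeFinder is exported from
# pytorch_lightning.callbacks and tunes the attribute named by `batch_arg_name`.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BatchSizeFinder

# "power" doubles the batch size each trial until an OOM is hit;
# "binsearch" then narrows in between the last working and first failing size.
finder = BatchSizeFinder(mode="power", init_val=2, max_trials=25, batch_arg_name="batch_size")
trainer = Trainer(callbacks=[finder])
# trainer.fit(model)  # on success, the tuned size is written back to model.batch_size
```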
@@ -124,11 +122,16 @@ def _run_power_scaling(self, trainer, pl_module, new_size):
        """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
        for _ in range(self.max_trials):
            garbage_collection_cuda()
-             changed = False

            try:
                self._try_loop_run(trainer)
                new_size, changed = self._adjust_batch_size(trainer, factor=2.0, desc="succeeded")
+
+                 if changed:
+                     # Force the dataloaders to reset as the batch size has changed
+                     self._reset_dataloaders(trainer, pl_module)
+                 else:
+                     break
            except RuntimeError as exception:
                if is_oom_error(exception):
                    garbage_collection_cuda()
@@ -137,12 +140,6 @@ def _run_power_scaling(self, trainer, pl_module, new_size):
                else:
                    raise  # some other error not memory related

-             if changed:
-                 # Force the train dataloader to reset as the batch size has changed
-                 self._reset_dataloaders(trainer, pl_module)
-             else:
-                 break
-
        return new_size

    def _run_binary_scaling(self, trainer, pl_module, new_size):
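The two hunks above move the dataloader-reset logic inside the `try` block so it only runs after a successful trial. As a standalone illustration of the "power" strategy itself, here is a minimal sketch with a hypothetical `run_trial` probe standing in for the trainer's short fit loop and `is_oom_error` as in `pytorch_lightning.utilities.memory`:

```python
# Minimal sketch of the "power" strategy, not the PR's exact code:
# double the batch size until the first OOM, then keep the last size that worked.
def power_scale(run_trial, is_oom_error, init_val=2, max_trials=25):
    size = init_val
    last_successful = None
    for _ in range(max_trials):
        try:
            run_trial(size)          # hypothetical short training run at this size
            last_successful = size
            size *= 2                # succeeded: try a larger batch next
        except RuntimeError as exc:
            if not is_oom_error(exc):
                raise                # some other error, not memory related
            break                    # OOM: stop and keep the last working size
    return last_successful if last_successful is not None else init_val
```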
@@ -156,7 +153,6 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
        count = 0
        while True:
            garbage_collection_cuda()
-             trainer.fit_loop.global_step = 0  # reset after each try
            try:
                # Try fit
                self._try_loop_run(trainer)
@@ -174,7 +170,7 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
                new_size, changed = self._adjust_batch_size(trainer, factor=2.0, desc="succeeded")

                if changed:
-                     # Force the train dataloader to reset as the batch size has changed
+                     # Force the dataloaders to reset as the batch size has changed
                    self._reset_dataloaders(trainer, pl_module)
                else:
                    break
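For the binsearch mode touched here, the overall idea (doubling on success, then binary-searching between the last working and the first failing size once an OOM occurs) can be sketched independently of the trainer plumbing; `fits(n)` below is a hypothetical probe that returns `True` when batch size `n` runs without OOM:

```python
# Self-contained sketch of the binsearch idea under the assumptions above.
def find_max_batch_size(fits, init_val=2, max_trials=25):
    low = None   # largest size known to fit
    high = None  # smallest size known to fail
    size = init_val
    for _ in range(max_trials):
        if fits(size):
            low = size
            if high is not None and high - low <= 1:
                break                # window closed: low is the answer
            size = size * 2 if high is None else (low + high) // 2
        else:
            high = size
            if low is not None and high - low <= 1:
                break
            size = max(1, size // 2) if low is None else (low + high) // 2
    return low
```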
@@ -287,7 +283,7 @@ def _adjust_batch_size(
                new batch size
            value: if a value is given, will override the batch size with this value.
                Note that the value of `factor` will not have an effect in this case
-             desc: either `succeeded` or `failed`. Used purely for logging
+             desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging

        Returns:
            The new batch size for the next trial and a bool that signals whether the
@@ -297,7 +293,7 @@ def _adjust_batch_size(
        batch_size = lightning_getattr(model, self.batch_arg_name)
        new_size = value if value is not None else int(batch_size * factor)
        if desc:
-             log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
+             rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")

        # TODO improve this for CombinedLoader
        if trainer.state.fn == TrainerFn.FITTING:
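A small worked example of the logging change in this last hunk, using only what is visible in the code above: with `batch_size` resolved to 32 via `lightning_getattr` (an assumed starting value), `factor=2.0`, no `value` override, and `desc="succeeded"`:

```python
# Illustrative values only; 32 is an assumed starting batch size.
batch_size, factor, value, desc = 32, 2.0, None, "succeeded"
new_size = value if value is not None else int(batch_size * factor)  # -> 64
# rank_zero_info now emits (on rank 0 only, which is the point of the change):
# "Batch size 32 succeeded, trying batch size 64"
```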