
Commit dbfdb0a

add BatchSizeFinderCallback callback
1 parent fde326d commit dbfdb0a

File tree

7 files changed: +356 -8 lines changed


pytorch_lightning/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
@@ -33,6 +34,7 @@
 __all__ = [
     "BackboneFinetuning",
     "BaseFinetuning",
+    "BatchSizeFinder",
     "Callback",
     "DeviceStatsMonitor",
     "EarlyStopping",

pytorch_lightning/callbacks/batch_size_finder.py (new file)

Lines changed: 323 additions & 0 deletions
@@ -0,0 +1,323 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
BatchSizeFinder
===============

Finds optimal batch size
"""

import logging
import os
import uuid
from typing import Optional, Tuple

from torch.utils.data.dataloader import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.warnings import rank_zero_warn

log = logging.getLogger(__name__)


class BatchSizeFinder(Callback):
    def __init__(self, mode: str = "power", steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name="batch_size"):
        mode = mode.lower()
        if mode not in ("power", "binsearch"):
            raise MisconfigurationException("`mode` should be either 'power' or 'binsearch'")

        self.mode = mode
        self.steps_per_trial = steps_per_trial
        self.init_val = init_val
        self.max_trials = max_trials
        self.batch_arg_name = batch_arg_name
        self.optimal_batch_size = init_val

    def scale_batch_size(self, trainer, pl_module):
        if trainer.fast_dev_run:
            rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
            return

        if not lightning_hasattr(pl_module, self.batch_arg_name):
            raise MisconfigurationException(
                f"Field {self.batch_arg_name} not found in both `model` and `model.hparams`"
            )

        if (
            hasattr(pl_module, self.batch_arg_name)
            and hasattr(pl_module, "hparams")
            and self.batch_arg_name in pl_module.hparams
        ):
            rank_zero_warn(
                f"Field `model.{self.batch_arg_name}` and `model.hparams.{self.batch_arg_name}` are mutually exclusive!"
                f" `model.{self.batch_arg_name}` will be used as the initial batch size for scaling."
                " If this is not the intended behavior, please remove either one."
            )

        if not trainer._data_connector._train_dataloader_source.is_module():
            raise MisconfigurationException(
                "The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`."
                " Please disable the feature or incorporate the dataloader into the model."
            )

        # Arguments we adjust during the batch size finder, save for restoring
        self._dump_params(trainer)

        # Set to values that are required by the algorithm
        self._reset_params(trainer)

        # Save initial model, that is loaded after batch size is found
        save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
        trainer.save_checkpoint(save_path)

        if trainer.progress_bar_callback:
            trainer.progress_bar_callback.disable()

        new_size, _ = self._adjust_batch_size(trainer, value=self.init_val)

        if self.mode == "power":
            new_size = self._run_power_scaling(trainer, pl_module, new_size)
        elif self.mode == "binsearch":
            new_size = self._run_binary_scaling(trainer, pl_module, new_size)

        garbage_collection_cuda()

        if trainer.is_global_zero:
            trainer.checkpoint_connector.restore(save_path)
            fs = get_filesystem(save_path)
            if fs.exists(save_path):
                fs.rm(save_path)

        self._restore_params(trainer)
        if trainer.progress_bar_callback:
            trainer.progress_bar_callback.enable()

        print(f"new batch size: {new_size}")
        self.optimal_batch_size = new_size

    def _run_power_scaling(self, trainer, pl_module, new_size):
        """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
        for _ in range(self.max_trials):
            garbage_collection_cuda()
            changed = False

            try:
                self._try_loop_run(trainer)
                new_size, changed = self._adjust_batch_size(trainer, factor=2.0, desc="succeeded")
            except RuntimeError as exception:
                if is_oom_error(exception):
                    garbage_collection_cuda()
                    new_size, _ = self._adjust_batch_size(trainer)
                    break
                else:
                    raise  # some other error not memory related

            if changed:
                # Force the train dataloader to reset as the batch size has changed
                self._reset_dataloaders(trainer, pl_module)
            else:
                break

        return new_size

    def _run_binary_scaling(self, trainer, pl_module, new_size):
        """Batch scaling mode where the size is initially doubled at each iteration until an OOM error is
        encountered.

        Hereafter, the batch size is further refined using a binary search.
        """
        low = 1
        high = None
        count = 0
        while True:
            garbage_collection_cuda()
            trainer.fit_loop.global_step = 0  # reset after each try
            try:
                # Try fit
                self._try_loop_run(trainer)
                count += 1
                if count > self.max_trials:
                    break
                # Double in size
                low = new_size
                if high:
                    if high - low <= 1:
                        break
                    midval = (high + low) // 2
                    new_size, changed = self._adjust_batch_size(trainer, value=midval, desc="succeeded")
                else:
                    new_size, changed = self._adjust_batch_size(trainer, factor=2.0, desc="succeeded")

                if changed:
                    # Force the train dataloader to reset as the batch size has changed
                    self._reset_dataloaders(trainer, pl_module)
                else:
                    break

            except RuntimeError as exception:
                # Only these errors should trigger an adjustment
                if is_oom_error(exception):
                    # If we fail in power mode, halve the size and return
                    garbage_collection_cuda()
                    high = new_size
                    midval = (high + low) // 2
                    new_size, _ = self._adjust_batch_size(trainer, value=midval, desc="failed")
                    if high - low <= 1:
                        break
                else:
                    raise  # some other error not memory related

        return new_size

    def _try_loop_run(self, trainer):
        if trainer.state.fn == TrainerFn.FITTING:
            trainer.fit_loop.global_step = self._dumped_params["global_step"]
            trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
            trainer.fit_loop.run()
        elif trainer.state.fn == TrainerFn.VALIDATING:
            trainer.validate_loop.run()
        elif trainer.state.fn == TrainerFn.TESTING:
            trainer.test_loop.run()
        elif trainer.state.fn == TrainerFn.PREDICTING:
            trainer.predict_loop.run()

    @staticmethod
    def _reset_dataloaders(trainer, pl_module):
        if trainer.state.fn == TrainerFn.FITTING:
            trainer.reset_train_dataloader(pl_module)
            trainer.reset_val_dataloader(pl_module)
        elif trainer.state.fn == TrainerFn.VALIDATING:
            trainer.reset_val_dataloader(pl_module)
        elif trainer.state.fn == TrainerFn.TESTING:
            trainer.reset_test_dataloader(pl_module)
        elif trainer.state.fn == TrainerFn.PREDICTING:
            trainer.reset_predict_dataloader(pl_module)

    def _dump_params(self, trainer):
        self._dumped_params = {
            "current_epoch": trainer.current_epoch,
            "global_step": trainer.global_step,
            "max_steps": trainer.max_steps,
            "logger": trainer.logger,
            "callbacks": trainer.callbacks,
            "limit_train_batches": trainer.limit_train_batches,
            "limit_val_batches": trainer.limit_val_batches,
            "limit_test_batches": trainer.limit_test_batches,
            "limit_predict_batches": trainer.limit_predict_batches,
        }

    def _reset_params(self, trainer):
        trainer.logger = DummyLogger() if trainer.logger is not None else None
        trainer.callbacks = []
        if trainer.state.fn == TrainerFn.FITTING:
            trainer.limit_val_batches = self.steps_per_trial
            trainer.fit_loop.max_steps = self.steps_per_trial
        elif trainer.state.fn == TrainerFn.VALIDATING:
            trainer.limit_val_batches = self.steps_per_trial
        elif trainer.state.fn == TrainerFn.TESTING:
            trainer.limit_test_batches = self.steps_per_trial
        elif trainer.state.fn == TrainerFn.PREDICTING:
            trainer.limit_predict_batches = self.steps_per_trial

    def _restore_params(self, trainer):
        trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
        trainer.fit_loop.global_step = self._dumped_params["global_step"]
        trainer.fit_loop.max_steps = self._dumped_params["max_steps"]
        trainer.logger = self._dumped_params["logger"]
        trainer.callbacks = self._dumped_params["callbacks"]
        trainer.limit_train_batches = self._dumped_params["limit_train_batches"]
        trainer.limit_val_batches = self._dumped_params["limit_val_batches"]
        trainer.limit_test_batches = self._dumped_params["limit_test_batches"]
        trainer.limit_predict_batches = self._dumped_params["limit_predict_batches"]

    def on_train_epoch_start(self, trainer, pl_module):
        self.scale_batch_size(trainer, pl_module)
        trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]

    def on_validation_epoch_start(self, trainer, pl_module):
        if not trainer.sanity_checking:
            self.scale_batch_size(trainer, pl_module)
            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]

    def on_test_epoch_start(self, trainer, pl_module):
        self.scale_batch_size(trainer, pl_module)
        trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]

    def on_predict_epoch_start(self, trainer, pl_module):
        self.scale_batch_size(trainer, pl_module)
        trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]

    def _adjust_batch_size(
        self,
        trainer: "pl.Trainer",
        factor: float = 1.0,
        value: Optional[int] = None,
        desc: Optional[str] = None,
    ) -> Tuple[int, bool]:
        """Helper function for adjusting the batch size.

        Args:
            trainer: instance of pytorch_lightning.Trainer
            factor: value which the old batch size is multiplied by to get the
                new batch size
            value: if a value is given, will override the batch size with this value.
                Note that the value of `factor` will not have an effect in this case
            desc: either `succeeded` or `failed`. Used purely for logging

        Returns:
            The new batch size for the next trial and a bool that signals whether the
            new value is different than the previous batch size.
        """
        model = trainer.lightning_module
        batch_size = lightning_getattr(model, self.batch_arg_name)
        new_size = value if value is not None else int(batch_size * factor)
        if desc:
            log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")

        # TODO improve this for CombinedLoader
        if trainer.state.fn == TrainerFn.FITTING:
            if not self._is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
                new_size = min(new_size, len(trainer.train_dataloader.dataset))
        if trainer.state.fn == TrainerFn.VALIDATING:
            if not self._is_valid_batch_size(new_size, trainer.val_dataloaders, trainer):
                new_size = min(new_size, len(trainer.val_dataloaders.dataset))
        if trainer.state.fn == TrainerFn.TESTING:
            if not self._is_valid_batch_size(new_size, trainer.test_dataloaders, trainer):
                new_size = min(new_size, len(trainer.test_dataloaders.dataset))
        if trainer.state.fn == TrainerFn.PREDICTING:
            if not self._is_valid_batch_size(new_size, trainer.predict_dataloaders, trainer):
                new_size = min(new_size, len(trainer.predict_dataloaders.dataset))

        changed = new_size != batch_size
        lightning_setattr(model, self.batch_arg_name, new_size)
        return new_size, changed

    @staticmethod
    def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
        module = trainer.lightning_module or trainer.datamodule
        return not has_len_all_ranks(dataloader, trainer.training_type_plugin, module) or batch_size <= len(dataloader)
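
Taken together, the callback is intended to be passed to the Trainer like any other callback: on the first epoch start of the selected stage it runs `scale_batch_size`, writes the result back to the module's batch-size attribute, and then removes itself from `trainer.callbacks`. Below is a minimal, hedged usage sketch; the `ToyModel` module, its dataset, and the Trainer arguments are illustrative and not part of this commit. Note that the dataloader must come from the LightningModule (or a datamodule), since dataloaders passed directly to `.fit()` are rejected by the check above.

# Hedged usage sketch (not part of this commit): a toy LightningModule that owns its
# dataloader and exposes `batch_size`, so the callback can scale it in place.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BatchSizeFinder


class ToyModel(pl.LightningModule):
    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size  # the attribute `batch_arg_name` points at
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # built from the module so the finder can rebuild it after every batch-size change
        ds = TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,)))
        return DataLoader(ds, batch_size=self.batch_size)


model = ToyModel()
trainer = pl.Trainer(max_epochs=1, callbacks=[BatchSizeFinder(mode="power", max_trials=5)])
trainer.fit(model)  # on_train_epoch_start runs the search, then the callback removes itself
print(model.batch_size)  # the scaled value written back via lightning_setattr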

pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 5 additions & 1 deletion
@@ -24,6 +24,7 @@
     RichProgressBar,
     TQDMProgressBar,
 )
+from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
 from pytorch_lightning.callbacks.timer import Timer
 from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info
@@ -302,4 +303,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
         """
         checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
         not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)]
-        return not_checkpoints + checkpoints
+        callbacks = not_checkpoints + checkpoints
+        batch_size_finder_callback = [c for c in callbacks if isinstance(c, BatchSizeFinder)]
+        other_callbacks = [c for c in callbacks if not isinstance(c, BatchSizeFinder)]
+        return batch_size_finder_callback + other_callbacks
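
The reordering above keeps the existing rule that ModelCheckpoint callbacks run last and additionally moves any BatchSizeFinder to the front of the list, so its epoch-start hook fires before other callbacks. A small illustrative sketch of the resulting order (the callback arguments are placeholders):

# Illustrative only: the order _reorder_callbacks is expected to produce for a mixed list.
from pytorch_lightning.callbacks import BatchSizeFinder, EarlyStopping, ModelCheckpoint

callbacks = [EarlyStopping(monitor="val_loss"), ModelCheckpoint(), BatchSizeFinder()]
# not_checkpoints + checkpoints       -> [EarlyStopping, BatchSizeFinder, ModelCheckpoint]
# BatchSizeFinder pulled to the front -> [BatchSizeFinder, EarlyStopping, ModelCheckpoint]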

pytorch_lightning/trainer/trainer.py

Lines changed: 8 additions & 3 deletions
@@ -1049,9 +1049,7 @@ def tune(
         """
         Trainer._log_api_event("tune")

-        self.state.fn = TrainerFn.TUNING
         self.state.status = TrainerStatus.RUNNING
-        self.tuning = True

         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(train_dataloaders, LightningDataModule):
@@ -1068,7 +1066,14 @@ def tune(
             model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
         )

-        result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs)
+        result = self.tuner._tune(
+            model,
+            train_dataloaders,
+            val_dataloaders,
+            datamodule,
+            scale_batch_size_kwargs=scale_batch_size_kwargs,
+            lr_find_kwargs=lr_find_kwargs,
+        )

         assert self.state.stopped
         self.tuning = False
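
For orientation, the user-facing entry point for this path is still `Trainer.tune`, which now forwards the dataloaders/datamodule it received down to `tuner._tune`. A hedged sketch of such a call follows (`ToyModel` is the placeholder module from the earlier sketch, and the `auto_lr_find`/`num_training` options are not part of this commit). Note that the old `scale_batch_size` path is deliberately short-circuited in this commit (see the batch_size_scaling.py change below), so only the learning-rate-finder branch would be exercised here.

# Hedged sketch: dataloaders/datamodule given to `tune` now flow through to the tuner.
trainer = pl.Trainer(auto_lr_find=True)
trainer.tune(ToyModel(), lr_find_kwargs={"num_training": 50})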

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ def scale_batch_size(
     batch_arg_name: str = "batch_size",
 ) -> Optional[int]:
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`"""
+    raise MisconfigurationException("this is gone")
+
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping batch size scaler since fast_dev_run is enabled.", UserWarning)
         return
