|
| 1 | +# Copyright The PyTorch Lightning team. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +r""" |
| 15 | +BatchSizeFinder |
| 16 | +=============== |
| 17 | +
|
| 18 | +Finds optimal batch size |
| 19 | +""" |
| 20 | + |
| 21 | +import logging |
| 22 | +import os |
| 23 | +import uuid |
| 24 | +from typing import Optional, Tuple |
| 25 | + |
| 26 | +from torch.utils.data.dataloader import DataLoader |
| 27 | + |
| 28 | +import pytorch_lightning as pl |
| 29 | +from pytorch_lightning.callbacks.base import Callback |
| 30 | +from pytorch_lightning.loggers.base import DummyLogger |
| 31 | +from pytorch_lightning.trainer.states import TrainerFn |
| 32 | +from pytorch_lightning.utilities.cloud_io import get_filesystem |
| 33 | +from pytorch_lightning.utilities.data import has_len_all_ranks |
| 34 | +from pytorch_lightning.utilities.exceptions import MisconfigurationException |
| 35 | +from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error |
| 36 | +from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr |
| 37 | +from pytorch_lightning.utilities.warnings import rank_zero_warn |
| 38 | + |
| 39 | +log = logging.getLogger(__name__) |
| 40 | + |
| 41 | + |
| 42 | +class BatchSizeFinder(Callback): |
| 43 | + def __init__(self, mode: str = "power", steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name="batch_size"): |
| 44 | + |
| 45 | + mode = mode.lower() |
| 46 | + if mode not in ("power", "binsearch"): |
| 47 | + raise MisconfigurationException("`mode` should be either 'power' or 'binsearch'") |
| 48 | + |
| 49 | + self.mode = mode |
| 50 | + self.steps_per_trial = steps_per_trial |
| 51 | + self.init_val = init_val |
| 52 | + self.max_trials = max_trials |
| 53 | + self.batch_arg_name = batch_arg_name |
| 54 | + self.optimal_batch_size = init_val |
| 55 | + |
| 56 | + def scale_batch_size(self, trainer, pl_module): |
| 57 | + if trainer.fast_dev_run: |
| 58 | + rank_zero_warn("Skiping batch size scaler since `fast_dev_run` is enabled.") |
| 59 | + return |
| 60 | + |
| 61 | + if not lightning_hasattr(pl_module, self.batch_arg_name): |
| 62 | + raise MisconfigurationException( |
| 63 | + f"Field {self.batch_arg_name} not found in both `model` and `model.hparams`" |
| 64 | + ) |
| 65 | + |
| 66 | + if not lightning_hasattr(pl_module, self.batch_arg_name): |
| 67 | + raise MisconfigurationException( |
| 68 | + f"Field {self.batch_arg_name} not found in both `model` and `model.hparams`" |
| 69 | + ) |
| 70 | + |
| 71 | + if ( |
| 72 | + hasattr(pl_module, self.batch_arg_name) |
| 73 | + and hasattr(pl_module, "hparams") |
| 74 | + and self.batch_arg_name in pl_module.hparams |
| 75 | + ): |
| 76 | + rank_zero_warn( |
| 77 | + f"Field `model.{self.batch_arg_name}` and `model.hparams.{self.batch_arg_name}` are mutually exclusive!" |
| 78 | + f" `model.{self.batch_arg_name}` will be used as the initial batch size for scaling." |
| 79 | + " If this is not the intended behavior, please remove either one." |
| 80 | + ) |
| 81 | + |
| 82 | + if not trainer._data_connector._train_dataloader_source.is_module(): |
| 83 | + raise MisconfigurationException( |
| 84 | + "The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`." |
| 85 | + " Please disable the feature or incorporate the dataloader into the model." |
| 86 | + ) |
| 87 | + |
| 88 | + # Arguments we adjust during the batch size finder, save for restoring |
| 89 | + self._dump_params(trainer) |
| 90 | + |
| 91 | + # Set to values that are required by the algorithm |
| 92 | + self._reset_params(trainer) |
| 93 | + |
| 94 | + # Save initial model, that is loaded after batch size is found |
| 95 | + save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt") |
| 96 | + trainer.save_checkpoint(save_path) |
| 97 | + |
| 98 | + if trainer.progress_bar_callback: |
| 99 | + trainer.progress_bar_callback.disable() |
| 100 | + |
| 101 | + new_size, _ = self._adjust_batch_size(trainer, value=self.init_val) |
| 102 | + |
| 103 | + if self.mode == "power": |
| 104 | + new_size = self._run_power_scaling(trainer, pl_module, new_size) |
| 105 | + elif self.mode == "binsearch": |
| 106 | + new_size = self._run_binary_scaling(trainer, pl_module, new_size) |
| 107 | + |
| 108 | + garbage_collection_cuda() |
| 109 | + |
| 110 | + if trainer.is_global_zero: |
| 111 | + trainer.checkpoint_connector.restore(save_path) |
| 112 | + fs = get_filesystem(save_path) |
| 113 | + if fs.exists(save_path): |
| 114 | + fs.rm(save_path) |
| 115 | + |
| 116 | + self._restore_params(trainer) |
| 117 | + if trainer.progress_bar_callback: |
| 118 | + trainer.progress_bar_callback.enable() |
| 119 | + |
| 120 | + print(f"new batch size: {new_size}") |
| 121 | + self.optimal_batch_size = new_size |
| 122 | + |
| 123 | + def _run_power_scaling(self, trainer, pl_module, new_size): |
| 124 | + """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.""" |
| 125 | + for _ in range(self.max_trials): |
| 126 | + garbage_collection_cuda() |
| 127 | + changed = False |
| 128 | + |
| 129 | + try: |
| 130 | + self._try_loop_run(trainer) |
| 131 | + new_size, changed = self._adjust_batch_size(trainer, factor=2.0, desc="succeeded") |
| 132 | + except RuntimeError as exception: |
| 133 | + if is_oom_error(exception): |
| 134 | + garbage_collection_cuda() |
| 135 | + new_size, _ = self._adjust_batch_size(trainer) |
| 136 | + break |
| 137 | + else: |
| 138 | + raise # some other error not memory related |
| 139 | + |
| 140 | + if changed: |
| 141 | + # Force the train dataloader to reset as the batch size has changed |
| 142 | + self._reset_dataloaders(trainer, pl_module) |
| 143 | + else: |
| 144 | + break |
| 145 | + |
| 146 | + return new_size |
| 147 | + |
| 148 | + def _run_binary_scaling(self, trainer, pl_module, new_size): |
| 149 | + """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is |
| 150 | + encountered. |
| 151 | +
|
| 152 | + Hereafter, the batch size is further refined using a binary search |
| 153 | + """ |
| 154 | + low = 1 |
| 155 | + high = None |
| 156 | + count = 0 |
| 157 | + while True: |
| 158 | + garbage_collection_cuda() |
| 159 | + trainer.fit_loop.global_step = 0 # reset after each try |
| 160 | + try: |
| 161 | + # Try fit |
| 162 | + self._try_loop_run(trainer) |
| 163 | + count += 1 |
| 164 | + if count > self.max_trials: |
| 165 | + break |
| 166 | + # Double in size |
| 167 | + low = new_size |
| 168 | + if high: |
| 169 | + if high - low <= 1: |
| 170 | + break |
| 171 | + midval = (high + low) // 2 |
| 172 | + new_size, changed = self._adjust_batch_size(trainer, value=midval, desc="succeeded") |
| 173 | + else: |
| 174 | + new_size, changed = self._adjust_batch_size(trainer, factor=2.0, desc="succeeded") |
| 175 | + |
| 176 | + if changed: |
| 177 | + # Force the train dataloader to reset as the batch size has changed |
| 178 | + self._reset_dataloaders(trainer, pl_module) |
| 179 | + else: |
| 180 | + break |
| 181 | + |
| 182 | + except RuntimeError as exception: |
| 183 | + # Only these errors should trigger an adjustment |
| 184 | + if is_oom_error(exception): |
| 185 | + # If we fail in power mode, half the size and return |
| 186 | + garbage_collection_cuda() |
| 187 | + high = new_size |
| 188 | + midval = (high + low) // 2 |
| 189 | + new_size, _ = self._adjust_batch_size(trainer, value=midval, desc="failed") |
| 190 | + if high - low <= 1: |
| 191 | + break |
| 192 | + else: |
| 193 | + raise # some other error not memory related |
| 194 | + |
| 195 | + return new_size |
| 196 | + |
| 197 | + def _try_loop_run(self, trainer): |
| 198 | + if trainer.state.fn == TrainerFn.FITTING: |
| 199 | + trainer.fit_loop.global_step = self._dumped_params["global_step"] |
| 200 | + trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"] |
| 201 | + trainer.fit_loop.run() |
| 202 | + elif trainer.state.fn == TrainerFn.VALIDATING: |
| 203 | + trainer.validate_loop.run() |
| 204 | + elif trainer.state.fn == TrainerFn.TESTING: |
| 205 | + trainer.test_loop.run() |
| 206 | + elif trainer.state.fn == TrainerFn.PREDICTING: |
| 207 | + trainer.predict_loop.run() |
| 208 | + |
| 209 | + @staticmethod |
| 210 | + def _reset_dataloaders(trainer, pl_module): |
| 211 | + if trainer.state.fn == TrainerFn.FITTING: |
| 212 | + trainer.reset_train_dataloader(pl_module) |
| 213 | + trainer.reset_val_dataloader(pl_module) |
| 214 | + elif trainer.state.fn == TrainerFn.VALIDATING: |
| 215 | + trainer.reset_val_dataloader(pl_module) |
| 216 | + elif trainer.state.fn == TrainerFn.TESTING: |
| 217 | + trainer.reset_test_dataloader(pl_module) |
| 218 | + elif trainer.state.fn == TrainerFn.PREDICTING: |
| 219 | + trainer.reset_predict_dataloader(pl_module) |
| 220 | + |
| 221 | + def _dump_params(self, trainer): |
| 222 | + self._dumped_params = { |
| 223 | + "current_epoch": trainer.current_epoch, |
| 224 | + "global_step": trainer.global_step, |
| 225 | + "max_steps": trainer.max_steps, |
| 226 | + "logger": trainer.logger, |
| 227 | + "callbacks": trainer.callbacks, |
| 228 | + "limit_train_batches": trainer.limit_train_batches, |
| 229 | + "limit_val_batches": trainer.limit_val_batches, |
| 230 | + "limit_test_batches": trainer.limit_test_batches, |
| 231 | + "limit_predict_batches": trainer.limit_predict_batches, |
| 232 | + } |
| 233 | + |
| 234 | + def _reset_params(self, trainer): |
| 235 | + trainer.logger = DummyLogger() if trainer.logger is not None else None |
| 236 | + trainer.callbacks = [] |
| 237 | + if trainer.state.fn == TrainerFn.FITTING: |
| 238 | + trainer.limit_val_batches = self.steps_per_trial |
| 239 | + trainer.fit_loop.max_steps = self.steps_per_trial |
| 240 | + elif trainer.state.fn == TrainerFn.VALIDATING: |
| 241 | + trainer.limit_val_batches = self.steps_per_trial |
| 242 | + elif trainer.state.fn == TrainerFn.TESTING: |
| 243 | + trainer.limit_test_batches = self.steps_per_trial |
| 244 | + elif trainer.state.fn == TrainerFn.PREDICTING: |
| 245 | + trainer.limit_predict_batches = self.steps_per_trial |
| 246 | + |
| 247 | + def _restore_params(self, trainer): |
| 248 | + trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"] |
| 249 | + trainer.fit_loop.global_step = self._dumped_params["global_step"] |
| 250 | + trainer.fit_loop.max_steps = self._dumped_params["max_steps"] |
| 251 | + trainer.logger = self._dumped_params["logger"] |
| 252 | + trainer.callbacks = self._dumped_params["callbacks"] |
| 253 | + trainer.limit_train_batches = self._dumped_params["limit_train_batches"] |
| 254 | + trainer.limit_val_batches = self._dumped_params["limit_val_batches"] |
| 255 | + trainer.limit_test_batches = self._dumped_params["limit_test_batches"] |
| 256 | + trainer.limit_predict_batches = self._dumped_params["limit_predict_batches"] |
| 257 | + |
| 258 | + def on_train_epoch_start(self, trainer, pl_module): |
| 259 | + self.scale_batch_size(trainer, pl_module) |
| 260 | + trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)] |
| 261 | + |
| 262 | + def on_validation_epoch_start(self, trainer, pl_module): |
| 263 | + if not trainer.sanity_checking: |
| 264 | + self.scale_batch_size(trainer, pl_module) |
| 265 | + trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)] |
| 266 | + |
| 267 | + def on_test_epoch_start(self, trainer, pl_module): |
| 268 | + self.scale_batch_size(trainer, pl_module) |
| 269 | + trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)] |
| 270 | + |
| 271 | + def on_predict_epoch_start(self, trainer, pl_module): |
| 272 | + self.scale_batch_size(trainer, pl_module) |
| 273 | + trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)] |
| 274 | + |
| 275 | + def _adjust_batch_size( |
| 276 | + self, |
| 277 | + trainer: "pl.Trainer", |
| 278 | + factor: float = 1.0, |
| 279 | + value: Optional[int] = None, |
| 280 | + desc: Optional[str] = None, |
| 281 | + ) -> Tuple[int, bool]: |
| 282 | + """Helper function for adjusting the batch size. |
| 283 | +
|
| 284 | + Args: |
| 285 | + trainer: instance of pytorch_lightning.Trainer |
| 286 | + factor: value which the old batch size is multiplied by to get the |
| 287 | + new batch size |
| 288 | + value: if a value is given, will override the batch size with this value. |
| 289 | + Note that the value of `factor` will not have an effect in this case |
| 290 | + desc: either `succeeded` or `failed`. Used purely for logging |
| 291 | +
|
| 292 | + Returns: |
| 293 | + The new batch size for the next trial and a bool that signals whether the |
| 294 | + new value is different than the previous batch size. |
| 295 | + """ |
| 296 | + model = trainer.lightning_module |
| 297 | + batch_size = lightning_getattr(model, self.batch_arg_name) |
| 298 | + new_size = value if value is not None else int(batch_size * factor) |
| 299 | + if desc: |
| 300 | + log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}") |
| 301 | + |
| 302 | + # TODO improve this for CombinedLoader |
| 303 | + if trainer.state.fn == TrainerFn.FITTING: |
| 304 | + if not self._is_valid_batch_size(new_size, trainer.train_dataloader, trainer): |
| 305 | + new_size = min(new_size, len(trainer.train_dataloader.dataset)) |
| 306 | + if trainer.state.fn == TrainerFn.VALIDATING: |
| 307 | + if not self._is_valid_batch_size(new_size, trainer.val_dataloaders, trainer): |
| 308 | + new_size = min(new_size, len(trainer.val_dataloaders.dataset)) |
| 309 | + if trainer.state.fn == TrainerFn.TESTING: |
| 310 | + if not self._is_valid_batch_size(new_size, trainer.test_dataloaders, trainer): |
| 311 | + new_size = min(new_size, len(trainer.test_dataloaders.dataset)) |
| 312 | + if trainer.state.fn == TrainerFn.PREDICTING: |
| 313 | + if not self._is_valid_batch_size(new_size, trainer.predict_dataloaders, trainer): |
| 314 | + new_size = min(new_size, len(trainer.predict_dataloaders.dataset)) |
| 315 | + |
| 316 | + changed = new_size != batch_size |
| 317 | + lightning_setattr(model, self.batch_arg_name, new_size) |
| 318 | + return new_size, changed |
| 319 | + |
| 320 | + @staticmethod |
| 321 | + def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"): |
| 322 | + module = trainer.lightning_module or trainer.datamodule |
| 323 | + return not has_len_all_ranks(dataloader, trainer.training_type_plugin, module) or batch_size <= len(dataloader) |
0 commit comments