
Commit 5ee48cc

restore loops and integrate early exit

1 parent cb7077a commit 5ee48cc

File tree

4 files changed: +164, -46 lines

pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 151 additions & 36 deletions
@@ -20,6 +20,7 @@
 
 import os
 import uuid
+from copy import deepcopy
 from typing import Optional, Tuple
 
 from torch.utils.data.dataloader import DataLoader
@@ -38,7 +39,15 @@
 
 
 class BatchSizeFinder(Callback):
-    def __init__(self, mode: str = "power", steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name="batch_size"):
+    def __init__(
+        self,
+        mode: str = "power",
+        steps_per_trial=3,
+        init_val=2,
+        max_trials=25,
+        batch_arg_name="batch_size",
+        early_exit=False,
+    ):
 
         mode = mode.lower()
         if mode not in ("power", "binsearch"):
@@ -50,6 +59,7 @@ def __init__(self, mode: str = "power", steps_per_trial=3, init_val=2, max_trial
         self.max_trials = max_trials
         self.batch_arg_name = batch_arg_name
         self.optimal_batch_size = init_val
+        self.early_exit = early_exit
 
     def scale_batch_size(self, trainer, pl_module):
         if trainer.fast_dev_run:
@@ -90,7 +100,7 @@ def scale_batch_size(self, trainer, pl_module):
         self._reset_params(trainer)
 
         # Save initial model, that is loaded after batch size is found
-        save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
+        save_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
         trainer.save_checkpoint(save_path)
 
         if trainer.progress_bar_callback:
@@ -112,6 +122,7 @@ def scale_batch_size(self, trainer, pl_module):
             fs.rm(save_path)
 
         self._restore_params(trainer)
+
         if trainer.progress_bar_callback:
             trainer.progress_bar_callback.enable()
 
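Note: the next hunk threads the `changed` flag from `_adjust_batch_size` into `_run_binary_scaling` so the dataloaders are rebuilt whenever the midpoint trial actually changes the batch size. For orientation, a standalone schematic of the binsearch strategy this method implements — grow until an out-of-memory failure, then bisect between the last good size and the first failing one. `run_trial` and its failure threshold are invented here purely for illustration:

def run_trial(batch_size):
    # Hypothetical stand-in for running a few training steps; raise to
    # simulate a CUDA OOM once the batch no longer fits.
    if batch_size > 48:
        raise RuntimeError("CUDA out of memory (simulated)")

low, high = 1, None
new_size = 2  # init_val
while True:
    try:
        run_trial(new_size)
        low = new_size
        if high is not None:
            if high - low <= 1:
                break
            new_size = (high + low) // 2
        else:
            new_size *= 2  # "power" phase: double until the first failure
    except RuntimeError:
        high = new_size
        new_size = (high + low) // 2
        if high - low <= 1:
            break

print(new_size)  # converges on the largest size that still runs (48 here)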
@@ -182,7 +193,12 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
                     garbage_collection_cuda()
                     high = new_size
                     midval = (high + low) // 2
-                    new_size, _ = self._adjust_batch_size(trainer, value=midval, desc="failed")
+                    new_size, changed = self._adjust_batch_size(trainer, value=midval, desc="failed")
+
+                    if changed:
+                        # Force the dataloaders to reset as the batch size has changed
+                        self._reset_dataloaders(trainer, pl_module)
+
                     if high - low <= 1:
                         break
                 else:
@@ -193,14 +209,17 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
     def _try_loop_run(self, trainer):
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.fit_loop.global_step = self._dumped_params["global_step"]
-            trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
-            trainer.fit_loop.run()
+            # trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
+            loop = trainer.fit_loop
         elif trainer.state.fn == TrainerFn.VALIDATING:
-            trainer.validate_loop.run()
+            loop = trainer.validate_loop
         elif trainer.state.fn == TrainerFn.TESTING:
-            trainer.test_loop.run()
+            loop = trainer.test_loop
         elif trainer.state.fn == TrainerFn.PREDICTING:
-            trainer.predict_loop.run()
+            loop = trainer.predict_loop
+
+        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
+        loop.run()
 
     @staticmethod
     def _reset_dataloaders(trainer, pl_module):
@@ -216,57 +235,153 @@ def _reset_dataloaders(trainer, pl_module):
 
     def _dump_params(self, trainer):
         self._dumped_params = {
-            "current_epoch": trainer.current_epoch,
-            "global_step": trainer.global_step,
-            "max_steps": trainer.max_steps,
+            # "current_epoch": trainer.current_epoch,
             "logger": trainer.logger,
             "callbacks": trainer.callbacks,
-            "limit_train_batches": trainer.limit_train_batches,
-            "limit_val_batches": trainer.limit_val_batches,
-            "limit_test_batches": trainer.limit_test_batches,
-            "limit_predict_batches": trainer.limit_predict_batches,
         }
 
+        if trainer.state.fn == TrainerFn.FITTING:
+            loop = trainer.fit_loop
+            self._dumped_params["global_step"] = trainer.global_step
+            self._dumped_params["max_steps"] = trainer.max_steps
+            self._dumped_params["limit_val_batches"] = trainer.limit_val_batches
+        elif trainer.state.fn == TrainerFn.VALIDATING:
+            loop = trainer.validate_loop
+            self._dumped_params["limit_val_batches"] = trainer.limit_val_batches
+        elif trainer.state.fn == TrainerFn.TESTING:
+            loop = trainer.test_loop
+            self._dumped_params["limit_test_batches"] = trainer.limit_test_batches
+        elif trainer.state.fn == TrainerFn.PREDICTING:
+            loop = trainer.predict_loop
+            self._dumped_params["limit_predict_batches"] = trainer.limit_predict_batches
+
+        self._dumped_params["loop_state_dict"] = deepcopy(loop.state_dict())
+        if hasattr(loop, "verbose"):
+            self._dumped_params["loop_verbose"] = loop.verbose
+
     def _reset_params(self, trainer):
         trainer.logger = DummyLogger() if trainer.logger is not None else None
         trainer.callbacks = []
+
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.limit_val_batches = self.steps_per_trial
             trainer.fit_loop.max_steps = self.steps_per_trial
         elif trainer.state.fn == TrainerFn.VALIDATING:
             trainer.limit_val_batches = self.steps_per_trial
+            trainer.validate_loop.verbose = False
         elif trainer.state.fn == TrainerFn.TESTING:
             trainer.limit_test_batches = self.steps_per_trial
+            trainer.test_loop.verbose = False
         elif trainer.state.fn == TrainerFn.PREDICTING:
             trainer.limit_predict_batches = self.steps_per_trial
 
     def _restore_params(self, trainer):
-        trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
-        trainer.fit_loop.global_step = self._dumped_params["global_step"]
-        trainer.fit_loop.max_steps = self._dumped_params["max_steps"]
         trainer.logger = self._dumped_params["logger"]
         trainer.callbacks = self._dumped_params["callbacks"]
-        trainer.limit_train_batches = self._dumped_params["limit_train_batches"]
-        trainer.limit_val_batches = self._dumped_params["limit_val_batches"]
-        trainer.limit_test_batches = self._dumped_params["limit_test_batches"]
-        trainer.limit_predict_batches = self._dumped_params["limit_predict_batches"]
 
-    def on_train_epoch_start(self, trainer, pl_module):
-        self.scale_batch_size(trainer, pl_module)
+        if trainer.state.fn == TrainerFn.FITTING:
+            # trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
+            trainer.fit_loop.global_step = self._dumped_params["global_step"]
+            loop = trainer.fit_loop
+            loop.max_steps = self._dumped_params["max_steps"]
+            trainer.limit_val_batches = self._dumped_params["limit_val_batches"]
+        elif trainer.state.fn == TrainerFn.VALIDATING:
+            loop = trainer.validate_loop
+            trainer.limit_val_batches = self._dumped_params["limit_val_batches"]
+        elif trainer.state.fn == TrainerFn.TESTING:
+            loop = trainer.test_loop
+            trainer.limit_test_batches = self._dumped_params["limit_test_batches"]
+        elif trainer.state.fn == TrainerFn.PREDICTING:
+            loop = trainer.predict_loop
+            trainer.limit_predict_batches = self._dumped_params["limit_predict_batches"]
+
+        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
+        if "loop_verbose" in self._dumped_params:
+            loop.verbose = self._dumped_params["loop_verbose"]
+
+    def pre_early_exit(self, trainer):
+        if trainer.state.fn == TrainerFn.FITTING:
+            trainer.should_stop = True
+            self._dumped_params["num_training_batches"] = trainer.num_training_batches
+            trainer.num_training_batches = 0
+        elif trainer.state.fn == TrainerFn.VALIDATING:
+            self._dumped_params["num_val_batches"] = trainer.num_val_batches
+            trainer.num_val_batches = [0]
+        elif trainer.state.fn == TrainerFn.TESTING:
+            self._dumped_params["num_test_batches"] = trainer.num_test_batches
+            trainer.num_test_batches = [0]
+        elif trainer.state.fn == TrainerFn.PREDICTING:
+            self._dumped_params["num_predict_batches"] = trainer.num_predict_batches
+            trainer.num_predict_batches = [0]
+
+    def post_early_exit(self, trainer):
+        if trainer.state.fn == TrainerFn.FITTING:
+            trainer.num_training_batches = self._dumped_params["num_training_batches"]
+            loop = trainer.fit_loop
+        if trainer.state.fn == TrainerFn.VALIDATING:
+            trainer.num_val_batches = self._dumped_params["num_val_batches"]
+            loop = trainer.validate_loop
+        if trainer.state.fn == TrainerFn.TESTING:
+            trainer.num_test_batches = self._dumped_params["num_test_batches"]
+            loop = trainer.test_loop
+        if trainer.state.fn == TrainerFn.PREDICTING:
+            trainer.num_predict_batches = self._dumped_params["num_predict_batches"]
+            loop = trainer.predict_loop
+
+        loop.load_state_dict(self._dumped_params["loop_state_dict"])
         trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
 
-    def on_validation_epoch_start(self, trainer, pl_module):
-        if not trainer.sanity_checking:
-            self.scale_batch_size(trainer, pl_module)
+    def on_fit_start(self, trainer, pl_module):
+        self.scale_batch_size(trainer, pl_module)
+
+        if self.early_exit:
+            self.pre_early_exit(trainer)
+        else:
+            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
+
+    def on_validation_start(self, trainer, pl_module):
+        if trainer.sanity_checking or trainer.state.fn != TrainerFn.VALIDATING:
+            return
+
+        if self.early_exit:
+            self.pre_early_exit(trainer)
+        else:
             trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
 
-    def on_test_epoch_start(self, trainer, pl_module):
+    def on_test_start(self, trainer, pl_module):
         self.scale_batch_size(trainer, pl_module)
-        trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
 
-    def on_predict_epoch_start(self, trainer, pl_module):
+        if self.early_exit:
+            self.pre_early_exit(trainer)
+        else:
+            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
+
+    def on_predict_start(self, trainer, pl_module):
         self.scale_batch_size(trainer, pl_module)
-        trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
+
+        if self.early_exit:
+            self.pre_early_exit(trainer)
+        else:
+            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
+
+    def on_fit_end(self, trainer, pl_module):
+        if self.early_exit:
+            self.post_early_exit(trainer)
+
+    def on_validation_end(self, trainer, pl_module):
+        if trainer.sanity_checking or trainer.state.fn != TrainerFn.VALIDATING:
+            return
+
+        if self.early_exit:
+            self.post_early_exit(trainer)
+
+    def on_test_end(self, trainer, pl_module):
+        if self.early_exit:
+            self.post_early_exit(trainer)
+
+    def on_predict_end(self, trainer, pl_module):
+        if self.early_exit:
+            self.post_early_exit(trainer)
 
     def _adjust_batch_size(
         self,
295410
if desc:
296411
rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
297412

298-
# TODO improve this for CombinedLoader
413+
# TODO improve this for CombinedLoader and multi dataloaders
299414
if trainer.state.fn == TrainerFn.FITTING:
300415
if not self._is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
301416
new_size = min(new_size, len(trainer.train_dataloader.dataset))
302417
if trainer.state.fn == TrainerFn.VALIDATING:
303418
if not self._is_valid_batch_size(new_size, trainer.val_dataloaders, trainer):
304-
new_size = min(new_size, len(trainer.val_dataloaders.dataset))
419+
new_size = min(new_size, len(trainer.val_dataloaders[0].dataset))
305420
if trainer.state.fn == TrainerFn.TESTING:
306421
if not self._is_valid_batch_size(new_size, trainer.test_dataloaders, trainer):
307-
new_size = min(new_size, len(trainer.test_dataloaders.dataset))
422+
new_size = min(new_size, len(trainer.test_dataloaders[0].dataset))
308423
if trainer.state.fn == TrainerFn.PREDICTING:
309424
if not self._is_valid_batch_size(new_size, trainer.predict_dataloaders, trainer):
310-
new_size = min(new_size, len(trainer.predict_dataloaders.dataset))
425+
new_size = min(new_size, len(trainer.predict_dataloaders[0].dataset))
311426

312427
changed = new_size != batch_size
313428
lightning_setattr(model, self.batch_arg_name, new_size)
@@ -316,4 +431,4 @@ def _adjust_batch_size(
     @staticmethod
     def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
         module = trainer.lightning_module or trainer.datamodule
-        return not has_len_all_ranks(dataloader, trainer.training_type_plugin, module) or batch_size <= len(dataloader)
+        return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
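Taken together, the callback now hooks the start and end of every trainer entry point instead of the epoch hooks. A minimal usage sketch against the API in this diff — `model` is assumed to be a LightningModule exposing a `batch_size` attribute (see `batch_arg_name`), and with the default `early_exit=False` the finder removes itself from the callback list once tuning is done:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder

# Attach the finder as an ordinary callback; it tunes at the start of fit.
finder = BatchSizeFinder(mode="binsearch", steps_per_trial=3, init_val=2, max_trials=25)
trainer = Trainer(max_epochs=1, callbacks=[finder])
trainer.fit(model)  # tuning runs first; model.batch_size is updated in place
print(finder.optimal_batch_size)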

pytorch_lightning/loops/base.py

Lines changed: 3 additions & 1 deletion
@@ -23,6 +23,7 @@
 from pytorch_lightning.trainer.progress import BaseProgress
 from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
 
 T = TypeVar("T")  # the output type of `run`
 
@@ -335,4 +336,5 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional
             v.reset(metrics=False)
 
         self.on_load_checkpoint(state_dict[prefix + "state_dict"])
-        self.restarting = True
+        if _fault_tolerant_training():
+            self.restarting = True
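Context for this guard: the finder above rolls loops back with `load_state_dict`, and before this change every restore also set `restarting = True`, so the next `run()` behaved like a mid-run resume rather than a fresh trial. Roughly, the round-trip the callback performs (a sketch assuming a `trainer` in scope, not the exact code):

from copy import deepcopy

# _dump_params: snapshot the loop before the trial run
snapshot = deepcopy(trainer.fit_loop.state_dict())
# ... run a few steps at the candidate batch size ...
# _restore_params: roll the loop back; with fault-tolerant training disabled,
# the loop is no longer flagged as restarting afterwards
trainer.fit_loop.load_state_dict(deepcopy(snapshot))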

pytorch_lightning/tuner/tuning.py

Lines changed: 1 addition & 1 deletion
@@ -204,7 +204,7 @@ def lr_find(
 
     def fit(self, model, train_dataloaders, val_dataloaders, datamodule, **batch_size_scale_kwargs):
         self.trainer.state.fn = None
-        batch_size_finder = BatchSizeFinder(**batch_size_scale_kwargs)
+        batch_size_finder = BatchSizeFinder(**batch_size_scale_kwargs, early_exit=True)
         self.trainer.callbacks = [batch_size_finder] + self.trainer.callbacks
         self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
         return batch_size_finder.optimal_batch_size
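Since `Tuner.fit` now passes `early_exit=True`, the run it launches stops as soon as the search finishes instead of continuing into a full fit. A hedged sketch of the caller-facing behavior, using the `scale_batch_size` entry point the tests below exercise:

# `model` is assumed to be a LightningModule with a `batch_size` attribute.
new_size = trainer.tuner.scale_batch_size(model, max_trials=25)
# The attribute is already updated via lightning_setattr; the return value
# simply reports it.
assert new_size == model.batch_size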

tests/tuner/test_scale_batch_size.py

Lines changed: 9 additions & 8 deletions
@@ -92,7 +92,7 @@ def test_model_reset_correctly(tmpdir):
         torch.eq(before_state_dict[key], after_state_dict[key])
     ), "Model was not reset correctly after scaling batch size"
 
-    assert not any(f for f in os.listdir(tmpdir) if f.startswith("scale_batch_size_temp_model"))
+    assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size_temp_model"))
 
 
 def test_trainer_reset_correctly(tmpdir):
@@ -105,18 +105,19 @@ def test_trainer_reset_correctly(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
 
     changed_attributes = [
-        "callbacks",
-        "checkpoint_callback",
-        "current_epoch",
-        "limit_train_batches",
-        "logger",
-        "max_steps",
         "global_step",
+        "limit_val_batches",
+        "max_steps",
+        "logger",
+        "callbacks",
     ]
     expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
-    trainer.tuner.scale_batch_size(model, max_trials=5)
+    expected_loop_state_dict = trainer.fit_loop.state_dict()
+    trainer.tuner.scale_batch_size(model, max_trials=64)
     actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
+    actual_loop_state_dict = trainer.fit_loop.state_dict()
 
+    assert expected_loop_state_dict == actual_loop_state_dict
     assert actual == expected
 
 