Commit 671607c

enable fast_dev_run test
1 parent 5ee48cc commit 671607c

File tree

3 files changed: +10 -7 lines changed

pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 7 additions & 4 deletions

@@ -63,7 +63,7 @@ def __init__(
 
     def scale_batch_size(self, trainer, pl_module):
         if trainer.fast_dev_run:
-            rank_zero_warn("Skiping batch size scaler since `fast_dev_run` is enabled.")
+            rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
             return
 
         if not lightning_hasattr(pl_module, self.batch_arg_name):
@@ -209,7 +209,6 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
     def _try_loop_run(self, trainer):
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.fit_loop.global_step = self._dumped_params["global_step"]
-            # trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
             loop = trainer.fit_loop
         elif trainer.state.fn == TrainerFn.VALIDATING:
             loop = trainer.validate_loop
@@ -235,7 +234,6 @@ def _reset_dataloaders(trainer, pl_module):
 
     def _dump_params(self, trainer):
         self._dumped_params = {
-            # "current_epoch": trainer.current_epoch,
             "logger": trainer.logger,
             "callbacks": trainer.callbacks,
         }
@@ -280,7 +278,6 @@ def _restore_params(self, trainer):
         trainer.callbacks = self._dumped_params["callbacks"]
 
         if trainer.state.fn == TrainerFn.FITTING:
-            # trainer.fit_loop.current_epoch = self._dumped_params["current_epoch"]
             trainer.fit_loop.global_step = self._dumped_params["global_step"]
             loop = trainer.fit_loop
             loop.max_steps = self._dumped_params["max_steps"]
@@ -300,6 +297,9 @@ def _restore_params(self, trainer):
             loop.verbose = self._dumped_params["loop_verbose"]
 
     def pre_early_exit(self, trainer):
+        if trainer.fast_dev_run:
+            return
+
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.should_stop = True
             self._dumped_params["num_training_batches"] = trainer.num_training_batches
@@ -315,6 +315,9 @@ def pre_early_exit(self, trainer):
             trainer.num_predict_batches = [0]
 
     def post_early_exit(self, trainer):
+        if trainer.fast_dev_run:
+            return
+
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.num_training_batches = self._dumped_params["num_training_batches"]
             loop = trainer.fit_loop
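
The additions to `pre_early_exit` and `post_early_exit` make the guard symmetric: under `fast_dev_run` nothing is dumped before the early exit, so nothing has to be restored afterwards and the trainer's batch counts are never zeroed out. A minimal, self-contained sketch of that pattern, with `_Trainer` and `BatchSizeFinderSketch` as hypothetical stand-ins rather than the real Lightning classes:

```python
import warnings
from dataclasses import dataclass


@dataclass
class _Trainer:
    """Hypothetical stand-in exposing only the attributes the guard touches."""
    fast_dev_run: bool = False
    num_training_batches: int = 100


class BatchSizeFinderSketch:
    """Every hook that would mutate trainer state bails out under fast_dev_run."""

    def __init__(self) -> None:
        self._dumped_params: dict = {}

    def scale_batch_size(self, trainer: _Trainer) -> None:
        if trainer.fast_dev_run:
            warnings.warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
            return
        # ... the real scaling search would run here ...

    def pre_early_exit(self, trainer: _Trainer) -> None:
        if trainer.fast_dev_run:
            return  # nothing was dumped, so do not limit the batch counts
        self._dumped_params["num_training_batches"] = trainer.num_training_batches
        trainer.num_training_batches = 0

    def post_early_exit(self, trainer: _Trainer) -> None:
        if trainer.fast_dev_run:
            return  # symmetric no-op: there is nothing to restore
        trainer.num_training_batches = self._dumped_params["num_training_batches"]


trainer = _Trainer(fast_dev_run=True)
finder = BatchSizeFinderSketch()
finder.scale_batch_size(trainer)  # warns and returns immediately
finder.pre_early_exit(trainer)    # no-op
finder.post_early_exit(trainer)   # no-op
assert trainer.num_training_batches == 100  # trainer state left untouched
```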

pytorch_lightning/tuner/lr_finder.py

Lines changed: 1 addition & 1 deletion

@@ -201,7 +201,7 @@ def lr_find(
 ) -> Optional[_LRFinder]:
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
     if trainer.fast_dev_run:
-        rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.")
+        rank_zero_warn("Skipping learning rate finder since `fast_dev_run` is enabled.")
         return
 
     # Determine lr attr
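
The only change in this file is cosmetic (backticks around `fast_dev_run`), but it is load-bearing for the re-enabled test, which checks the warning text with `pytest.warns(..., match=...)`. Because `match` is applied as a regular-expression search against the message, the emitted text and the test's f-string have to agree. A small check illustrating that coupling (the message strings are copied from the diff, everything else is illustrative):

```python
import re

expected = "Skipping learning rate finder since `fast_dev_run` is enabled."
new_message = "Skipping learning rate finder since `fast_dev_run` is enabled."
old_message = "Skipping learning rate finder since fast_dev_run is enabled."

# pytest.warns(..., match=pattern) uses a regex search on the warning message,
# so the backtick-free wording from before this commit would no longer match.
assert re.search(expected, new_message) is not None
assert re.search(expected, old_message) is None
```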

tests/trainer/flags/test_fast_dev_run.py

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@
 
 
 @pytest.mark.parametrize("tuner_alg", ["batch size scaler", "learning rate finder"])
-@pytest.mark.skip(reason="Temperory skip")
+# @pytest.mark.skip(reason="Temperory skip")
 def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
     """Test that tuner algorithms are skipped if fast dev run is enabled."""
 
@@ -24,7 +24,7 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
         auto_lr_find=(tuner_alg == "learning rate finder"),
         fast_dev_run=True,
     )
-    expected_message = f"Skipping {tuner_alg} since fast_dev_run is enabled."
+    expected_message = f"Skipping {tuner_alg} since `fast_dev_run` is enabled."
     with pytest.warns(UserWarning, match=expected_message):
         trainer.tune(model)
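
With the skip marker commented out, both parametrized cases run again and expect the backticked wording. A self-contained sketch of how the parametrized expectation lines up with the warnings the two tuners now emit (`fake_tune` is a hypothetical stand-in for `trainer.tune(model)` under `fast_dev_run=True`, not the Lightning API):

```python
import warnings

import pytest


def fake_tune(tuner_alg: str) -> None:
    # Hypothetical stand-in: under fast_dev_run both tuner algorithms
    # warn with this wording and return without doing any work.
    warnings.warn(f"Skipping {tuner_alg} since `fast_dev_run` is enabled.")


@pytest.mark.parametrize("tuner_alg", ["batch size scaler", "learning rate finder"])
def test_skip_on_fast_dev_run_tuner_sketch(tuner_alg):
    expected_message = f"Skipping {tuner_alg} since `fast_dev_run` is enabled."
    with pytest.warns(UserWarning, match=expected_message):
        fake_tune(tuner_alg)
```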
