
Commit 5ec2c83

temp rm from init
1 parent 3fb7fea commit 5ec2c83

2 files changed (+10 additions, -16 deletions)


pytorch_lightning/callbacks/__init__.py (0 additions, 2 deletions)

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
@@ -34,7 +33,6 @@
 __all__ = [
     "BackboneFinetuning",
     "BaseFinetuning",
-    "BatchSizeFinder",
     "Callback",
     "DeviceStatsMonitor",
     "EarlyStopping",

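The practical effect of this first change: `BatchSizeFinder` is no longer re-exported, so `from pytorch_lightning.callbacks import BatchSizeFinder` stops working, while the defining module is left in place (the commit message suggests the removal from `__init__` is temporary). A minimal sketch of the import path that keeps working:

    # Package-level import fails after this commit:
    #   from pytorch_lightning.callbacks import BatchSizeFinder  # ImportError
    # The module itself is untouched, so a direct import still resolves:
    from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
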
pytorch_lightning/callbacks/batch_size_finder.py (10 additions, 14 deletions)
@@ -18,7 +18,6 @@
 Finds optimal batch size
 """

-import logging
 import os
 import uuid
 from typing import Optional, Tuple
@@ -31,13 +30,12 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.data import has_len_all_ranks
+from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
 from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
 from pytorch_lightning.utilities.warnings import rank_zero_warn

-log = logging.getLogger(__name__)
-

 class BatchSizeFinder(Callback):
     def __init__(self, mode: str = "power", steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name="batch_size"):
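For orientation, the constructor signature above implies wiring along these lines; a hedged sketch of standard Lightning callback usage, not something shown in this diff:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder

    # Callback subclasses are passed to the Trainer via its callbacks list.
    trainer = Trainer(callbacks=[BatchSizeFinder(mode="power", steps_per_trial=3)])
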
@@ -124,11 +122,16 @@ def _run_power_scaling(self, trainer, pl_module, new_size):
         """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
         for _ in range(self.max_trials):
             garbage_collection_cuda()
-            changed = False

             try:
                 self._try_loop_run(trainer)
                 new_size, changed = self._adjust_batch_size(trainer, factor=2.0, desc="succeeded")
+
+                if changed:
+                    # Force the dataloaders to reset as the batch size has changed
+                    self._reset_dataloaders(trainer, pl_module)
+                else:
+                    break
             except RuntimeError as exception:
                 if is_oom_error(exception):
                     garbage_collection_cuda()
@@ -137,12 +140,6 @@ def _run_power_scaling(self, trainer, pl_module, new_size):
                 else:
                     raise  # some other error not memory related

-            if changed:
-                # Force the train dataloader to reset as the batch size has changed
-                self._reset_dataloaders(trainer, pl_module)
-            else:
-                break
-
         return new_size

     def _run_binary_scaling(self, trainer, pl_module, new_size):
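Note on the two hunks above: the reset-or-break block moves from after the `try`/`except` (where it needed the `changed = False` initialization) to directly inside the `try`, immediately after a successful size adjustment, and the comment's "train dataloader" becomes "dataloaders" to match `_reset_dataloaders`. A condensed sketch of the resulting loop shape; the OOM backoff that falls between the two hunks is not shown in this diff and is only summarized in a comment:

    for _ in range(self.max_trials):
        garbage_collection_cuda()            # free cached CUDA memory between trials
        try:
            self._try_loop_run(trainer)      # run a short trial at the current batch size
            # Double the batch size; `changed` is False if it could not actually grow.
            new_size, changed = self._adjust_batch_size(trainer, factor=2.0, desc="succeeded")
            if changed:
                self._reset_dataloaders(trainer, pl_module)  # rebuild loaders at the new size
            else:
                break
        except RuntimeError as exception:
            if not is_oom_error(exception):
                raise                        # some other error not memory related
            garbage_collection_cuda()        # OOM: back off (handling elided between hunks)
    return new_size
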
@@ -156,7 +153,6 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
         count = 0
         while True:
             garbage_collection_cuda()
-            trainer.fit_loop.global_step = 0  # reset after each try
             try:
                 # Try fit
                 self._try_loop_run(trainer)
@@ -174,7 +170,7 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
                 new_size, changed = self._adjust_batch_size(trainer, factor=2.0, desc="succeeded")

                 if changed:
-                    # Force the train dataloader to reset as the batch size has changed
+                    # Force the dataloaders to reset as the batch size has changed
                     self._reset_dataloaders(trainer, pl_module)
                 else:
                     break
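For context on `_run_binary_scaling` (these hunks only touch its retry loop), the usual binsearch strategy is to grow the batch size until a trial fails, then bisect between the last passing and the first failing size. A self-contained schematic under that assumption; `binsearch_scale` and `run_trial` are hypothetical stand-ins, not names from this file:

    def binsearch_scale(run_trial, init_val=2, max_trials=25):
        """Double-then-bisect search for the largest batch size that fits.

        run_trial(size) raises MemoryError when `size` does not fit; it stands in
        for the callback's _try_loop_run plus its is_oom_error check. Assumes
        init_val itself fits in memory.
        """
        low, high = init_val, None
        size = init_val
        for _ in range(max_trials):
            try:
                run_trial(size)
                low = size                   # size fits: raise the lower bound
                if high is not None and high - low <= 1:
                    break                    # bounds have converged
                size = size * 2 if high is None else (low + high) // 2
            except MemoryError:
                high = size                  # size fails: lower the upper bound
                if high - low <= 1:
                    break
                size = (low + high) // 2
        return low
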
@@ -287,7 +283,7 @@ def _adjust_batch_size(
                 new batch size
             value: if a value is given, will override the batch size with this value.
                 Note that the value of `factor` will not have an effect in this case
-            desc: either `succeeded` or `failed`. Used purely for logging
+            desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging

         Returns:
             The new batch size for the next trial and a bool that signals whether the
@@ -297,7 +293,7 @@ def _adjust_batch_size(
         batch_size = lightning_getattr(model, self.batch_arg_name)
         new_size = value if value is not None else int(batch_size * factor)
         if desc:
-            log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
+            rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")

         # TODO improve this for CombinedLoader
         if trainer.state.fn == TrainerFn.FITTING: