
Commit bdd7a88

mypy and circular import fix
1 parent 3613f20 commit bdd7a88

File tree

4 files changed: +46 -29 lines changed

docs/source/extensions/callbacks.rst

Lines changed: 1 addition & 0 deletions

@@ -90,6 +90,7 @@ Lightning has a few built-in callbacks.
     BackboneFinetuning
     BaseFinetuning
     BasePredictionWriter
+    BatchSizeFinder
     Callback
     DeviceStatsMonitor
     EarlyStopping

pytorch_lightning/callbacks/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
@@ -33,6 +34,8 @@
 __all__ = [
     "BackboneFinetuning",
     "BaseFinetuning",
+    "BasePredictionWriter",
+    "BatchSizeFinder",
     "Callback",
     "DeviceStatsMonitor",
     "EarlyStopping",
@@ -44,7 +47,6 @@
     "ModelCheckpoint",
     "ModelPruning",
     "ModelSummary",
-    "BasePredictionWriter",
     "ProgressBar",
     "ProgressBarBase",
     "QuantizationAwareTraining",

pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 35 additions & 21 deletions

@@ -27,10 +27,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.loggers.base import DummyLogger
-from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.data import has_len_all_ranks
 from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
 from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
@@ -42,11 +39,11 @@ class BatchSizeFinder(Callback):
     def __init__(
         self,
         mode: str = "power",
-        steps_per_trial=3,
-        init_val=2,
-        max_trials=25,
-        batch_arg_name="batch_size",
-    ):
+        steps_per_trial: int = 3,
+        init_val: int = 2,
+        max_trials: int = 25,
+        batch_arg_name: str = "batch_size",
+    ) -> None:
         """Callback try to find the largest batch size for a given model that does not give an out of memory (OOM)
         error. It works with both training and evalation. All you need to do is add it as a callback inside Trainer
         and call ``trainer.fit/validate/test/predict()``. Internally it calls the respective step function
@@ -90,7 +87,7 @@ def __init__(
 
         self._early_exit = False
 
-    def scale_batch_size(self, trainer, pl_module):
+    def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if trainer.fast_dev_run:
             rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
             return
@@ -165,7 +162,7 @@ def scale_batch_size(self, trainer, pl_module):
         if self._early_exit:
             raise _TunerExitException()
 
-    def _run_power_scaling(self, trainer, pl_module, new_size):
+    def _run_power_scaling(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int) -> int:
         """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
         for _ in range(self.max_trials):
             garbage_collection_cuda()
@@ -189,7 +186,7 @@ def _run_power_scaling(self, trainer, pl_module, new_size):
 
         return new_size
 
-    def _run_binary_scaling(self, trainer, pl_module, new_size):
+    def _run_binary_scaling(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int) -> int:
         """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is
         encountered.
 
@@ -242,7 +239,9 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
 
         return new_size
 
-    def _try_loop_run(self, trainer):
+    def _try_loop_run(self, trainer: "pl.Trainer") -> None:
+        from pytorch_lightning.trainer.states import TrainerFn
+
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.fit_loop.global_step = self._dumped_params["global_step"]
             loop = trainer.fit_loop
@@ -257,7 +256,9 @@ def _try_loop_run(self, trainer):
         loop.run()
 
     @staticmethod
-    def _reset_dataloaders(trainer, pl_module):
+    def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        from pytorch_lightning.trainer.states import TrainerFn
+
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.reset_train_dataloader(pl_module)
             trainer.reset_val_dataloader(pl_module)
@@ -268,7 +269,9 @@ def _reset_dataloaders(trainer, pl_module):
         elif trainer.state.fn == TrainerFn.PREDICTING:
             trainer.reset_predict_dataloader(pl_module)
 
-    def _dump_params(self, trainer):
+    def _dump_params(self, trainer: "pl.Trainer") -> None:
+        from pytorch_lightning.trainer.states import TrainerFn
+
         self._dumped_params = {
             "logger": trainer.logger,
             "callbacks": trainer.callbacks,
@@ -293,7 +296,10 @@ def _dump_params(self, trainer):
         if hasattr(loop, "verbose"):
             self._dumped_params["loop_verbose"] = loop.verbose
 
-    def _reset_params(self, trainer):
+    def _reset_params(self, trainer: "pl.Trainer") -> None:
+        from pytorch_lightning.loggers.base import DummyLogger
+        from pytorch_lightning.trainer.states import TrainerFn
+
         trainer.logger = DummyLogger() if trainer.logger is not None else None
         trainer.callbacks = []
 
@@ -309,7 +315,9 @@ def _reset_params(self, trainer):
         elif trainer.state.fn == TrainerFn.PREDICTING:
             trainer.limit_predict_batches = self.steps_per_trial
 
-    def _restore_params(self, trainer):
+    def _restore_params(self, trainer: "pl.Trainer") -> None:
+        from pytorch_lightning.trainer.states import TrainerFn
+
         trainer.logger = self._dumped_params["logger"]
         trainer.callbacks = self._dumped_params["callbacks"]
 
@@ -332,19 +340,21 @@ def _restore_params(self, trainer):
         if "loop_verbose" in self._dumped_params:
             loop.verbose = self._dumped_params["loop_verbose"]
 
-    def on_fit_start(self, trainer, pl_module):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.scale_batch_size(trainer, pl_module)
 
-    def on_validation_start(self, trainer, pl_module):
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        from pytorch_lightning.trainer.states import TrainerFn
+
         if trainer.sanity_checking or trainer.state.fn != TrainerFn.VALIDATING:
             return
 
         self.scale_batch_size(trainer, pl_module)
 
-    def on_test_start(self, trainer, pl_module):
+    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.scale_batch_size(trainer, pl_module)
 
-    def on_predict_start(self, trainer, pl_module):
+    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.scale_batch_size(trainer, pl_module)
 
     def _adjust_batch_size(
@@ -368,6 +378,8 @@ def _adjust_batch_size(
             The new batch size for the next trial and a bool that signals whether the
             new value is different than the previous batch size.
         """
+        from pytorch_lightning.trainer.states import TrainerFn
+
         model = trainer.lightning_module
         batch_size = lightning_getattr(model, self.batch_arg_name)
         new_size = value if value is not None else int(batch_size * factor)
@@ -393,6 +405,8 @@ def _adjust_batch_size(
         return new_size, changed
 
     @staticmethod
-    def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
+    def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer") -> bool:
+        from pytorch_lightning.utilities.data import has_len_all_ranks
+
         module = trainer.lightning_module or trainer.datamodule
         return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
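The circular-import part of the fix is the move of `DummyLogger`, `TrainerFn`, and `has_len_all_ranks` from module-level imports to imports inside the methods that use them: a function-local import is resolved at call time, after every module in the cycle has finished initializing, so importing `pytorch_lightning.callbacks` no longer pulls in the trainer/logger modules while they may still be mid-import. A self-contained sketch of that pattern, with throwaway module names (`pkg_a` and `pkg_b` are illustrative only, not part of the codebase):

import sys
import tempfile
from pathlib import Path

tmp = Path(tempfile.mkdtemp())

# pkg_a imports pkg_b at module level, then finishes its own initialization.
(tmp / "pkg_a.py").write_text("import pkg_b\nVALUE = 'set at the end of pkg_a'\n")

# pkg_b also needs pkg_a, but only imports it inside the function, at call time.
(tmp / "pkg_b.py").write_text(
    "def read_value():\n"
    "    from pkg_a import VALUE  # deferred: pkg_a is fully initialized by now\n"
    "    return VALUE\n"
)

sys.path.insert(0, str(tmp))
import pkg_a  # no ImportError: pkg_b touches pkg_a only when read_value() runs

print(pkg_a.pkg_b.read_value())

If pkg_b instead did `from pkg_a import VALUE` at module level, `import pkg_a` above would fail with "partially initialized module" because VALUE is not yet defined when pkg_b starts loading; that is the failure mode the deferred imports in batch_size_finder.py avoid.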

pytorch_lightning/loops/base.py

Lines changed: 7 additions & 7 deletions

@@ -294,19 +294,19 @@ def state_dict(
 
         return destination
 
-    def load_state_dict(self, state_dict: Dict, prefix: str = "", metrics: Optional[Dict[str, Metric]] = None) -> None:
+    def load_state_dict(
+        self,
+        state_dict: Dict,
+        prefix: str = "",
+        metrics: Optional[Dict[str, Metric]] = None,
+    ) -> None:
         """Loads the state of this loop and all its children."""
         self._load_from_state_dict(state_dict.copy(), prefix, metrics)
         for k, v in self.__dict__.items():
             if isinstance(v, Loop):
                 v.load_state_dict(state_dict.copy(), prefix + k + ".")
 
-    def _load_from_state_dict(
-        self,
-        state_dict: Dict,
-        prefix: str,
-        metrics: Optional[Dict[str, Metric]] = None,
-    ) -> None:
+    def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
         for k, v in self.__dict__.items():
             key = prefix + k
             if key not in state_dict:
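For context, the surrounding code (unchanged here apart from the signature reflow) walks the loop's attributes and recurses into child loops with an extended key prefix. A toy analogue of that prefix-keyed recursion, illustrative only and not Lightning's actual key layout:

from typing import Any, Dict

class ToyLoop:
    """Toy stand-in for a Loop: owns a `progress` counter and optional child loops."""

    def __init__(self, **children: "ToyLoop") -> None:
        self.progress: int = 0
        for name, child in children.items():
            setattr(self, name, child)

    def load_state_dict(self, state_dict: Dict[str, Any], prefix: str = "") -> None:
        # Restore this loop's own entry, keyed by the current prefix...
        self.progress = state_dict[prefix + "state_dict"]["progress"]
        # ...then recurse into children, extending the prefix per attribute name.
        for k, v in self.__dict__.items():
            if isinstance(v, ToyLoop):
                v.load_state_dict(state_dict, prefix + k + ".")

outer = ToyLoop(epoch_loop=ToyLoop())
outer.load_state_dict({
    "state_dict": {"progress": 3},
    "epoch_loop.state_dict": {"progress": 7},
})
print(outer.progress, outer.epoch_loop.progress)  # -> 3 7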
