Commit 697f250

Prefer local imports for optional dependencies (#18551)
1 parent 1cee84c commit 697f250

9 files changed, 40 additions and 40 deletions
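Most of the nine files apply the same refactor: imports of optional packages (rich, omegaconf, torchvision, deepspeed) move from module scope into the functions that use them, so importing lightning.pytorch itself no longer touches packages that may not be installed. A minimal sketch of the before/after shape of the change, using a hypothetical print_fancy helper (the _RICH_AVAILABLE-style flags in the real code are RequirementCache objects):

    from lightning_utilities.core.imports import RequirementCache

    _RICH_AVAILABLE = RequirementCache("rich")

    # Before: a module-level guarded import; merely importing this module
    # imports `rich` whenever it happens to be installed.
    #
    #     if _RICH_AVAILABLE:
    #         from rich import get_console

    # After: the import is deferred into the function that needs it.
    def print_fancy(message: str) -> None:
        if not _RICH_AVAILABLE:
            print(message)  # plain fallback when the optional package is absent
            return

        from rich import get_console  # local import: runs only if rich is installed

        get_console().print(message, style="bold")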

src/lightning/pytorch/callbacks/rich_model_summary.py

Lines changed: 3 additions & 4 deletions

@@ -17,10 +17,6 @@
 from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from lightning.pytorch.utilities.model_summary import get_human_readable_count

-if _RICH_AVAILABLE:  # type: ignore[has-type]
-    from rich import get_console
-    from rich.table import Table
-

 class RichModelSummary(ModelSummary):
     r"""Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule` with `rich text
@@ -74,6 +70,9 @@ def summarize(
         model_size: float,
         **summarize_kwargs: Any,
     ) -> None:
+        from rich import get_console
+        from rich.table import Table
+
         console = get_console()

         header_style: str = summarize_kwargs.get("header_style", "bold magenta")

src/lightning/pytorch/core/saving.py

Lines changed: 8 additions & 7 deletions

@@ -41,16 +41,10 @@
 from lightning.pytorch.utilities.parsing import AttributeDict, parse_class_init_keys
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn

-log = logging.getLogger(__name__)
-
-if _OMEGACONF_AVAILABLE:
-    from omegaconf import OmegaConf
-    from omegaconf.dictconfig import DictConfig
-    from omegaconf.errors import UnsupportedValueType, ValidationError
-
 if TYPE_CHECKING:
     from torch.storage import UntypedStorage

+log = logging.getLogger(__name__)
 # the older shall be on the top
 CHECKPOINT_PAST_HPARAMS_KEYS = ("hparams", "module_arguments")  # used in 0.7.6

@@ -296,6 +290,9 @@ def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Dict[str, Any]:
         hparams = yaml.full_load(fp)

     if _OMEGACONF_AVAILABLE and use_omegaconf:
+        from omegaconf import OmegaConf
+        from omegaconf.errors import UnsupportedValueType, ValidationError
+
         with contextlib.suppress(UnsupportedValueType, ValidationError):
             return OmegaConf.create(hparams)
     return hparams
@@ -322,6 +319,10 @@ def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None:

     # saving with OmegaConf objects
     if _OMEGACONF_AVAILABLE and use_omegaconf:
+        from omegaconf import OmegaConf
+        from omegaconf.dictconfig import DictConfig
+        from omegaconf.errors import UnsupportedValueType, ValidationError
+
         # deepcopy: hparams from user shouldn't be resolved
         hparams = deepcopy(hparams)
         hparams = apply_to_collection(hparams, DictConfig, OmegaConf.to_container, resolve=True)

src/lightning/pytorch/demos/mnist_datamodule.py

Lines changed: 6 additions & 6 deletions

@@ -28,9 +28,6 @@
 from lightning.pytorch import LightningDataModule
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE

-if _TORCHVISION_AVAILABLE:
-    from torchvision import transforms as transform_lib
-
 _DATASETS_PATH = "./data"
@@ -244,11 +241,14 @@ def test_dataloader(self) -> DataLoader:
     def default_transforms(self) -> Optional[Callable]:
         if not _TORCHVISION_AVAILABLE:
             return None
+
+        from torchvision import transforms
+
         if self.normalize:
-            mnist_transforms = transform_lib.Compose(
-                [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
+            mnist_transforms = transforms.Compose(
+                [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
             )
         else:
-            mnist_transforms = transform_lib.ToTensor()
+            mnist_transforms = transforms.ToTensor()

         return mnist_transforms

src/lightning/pytorch/loggers/tensorboard.py

Lines changed: 3 additions & 3 deletions

@@ -37,9 +37,6 @@

 log = logging.getLogger(__name__)

-if _OMEGACONF_AVAILABLE:
-    from omegaconf import Container, OmegaConf
-
 # Skip doctests if requirements aren't available
 if not (_TENSORBOARD_AVAILABLE or _TENSORBOARDX_AVAILABLE):
     __doctest_skip__ = ["TensorBoardLogger", "TensorBoardLogger.*"]
@@ -178,6 +175,9 @@ def log_hyperparams(  # type: ignore[override]
             metrics: Dictionary with metric names as keys and measured quantities as values

         """
+        if _OMEGACONF_AVAILABLE:
+            from omegaconf import Container, OmegaConf
+
         params = _convert_params(params)

         # store params to output

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 3 additions & 4 deletions

@@ -44,10 +44,6 @@
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature

-if _RICH_AVAILABLE:
-    from rich import get_console
-    from rich.table import Column, Table
-

 class _EvaluationLoop(_Loop):
     """Top-level loop where validation/testing starts."""
@@ -531,6 +527,9 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None:
         table_headers.insert(0, f"{stage} Metric".capitalize())

         if _RICH_AVAILABLE:
+            from rich import get_console
+            from rich.table import Column, Table
+
             columns = [Column(h, justify="center", style="magenta", width=max_length) for h in table_headers]
             columns[0].style = "cyan"

src/lightning/pytorch/strategies/ddp.py

Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@
 import logging
 from contextlib import nullcontext
 from datetime import timedelta
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed
@@ -47,7 +47,7 @@
 from lightning.pytorch.utilities.exceptions import _augment_message
 from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_only

-if torch.distributed.is_available():
+if TYPE_CHECKING:
     from torch.distributed.algorithms.model_averaging.averagers import ModelAverager

 log = logging.getLogger(__name__)
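The ddp.py hunk uses a different tool for the same goal: typing.TYPE_CHECKING is False at runtime but treated as True by static type checkers, so the ModelAverager import now exists purely for annotations and never executes. A minimal sketch of the idiom (the class here is hypothetical):

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Seen by mypy/pyright only; never executed at runtime, so the module
        # stays importable even where torch.distributed is unavailable.
        from torch.distributed.algorithms.model_averaging.averagers import ModelAverager

    class SketchStrategy:
        def __init__(self) -> None:
            # Quoted annotations are not evaluated at runtime either.
            self._model_averager: Optional["ModelAverager"] = None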

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 4 additions & 5 deletions

@@ -35,11 +35,7 @@
 from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

-if _OMEGACONF_AVAILABLE:
-    from omegaconf import Container
-
-
-log: logging.Logger = logging.getLogger(__name__)
+log = logging.getLogger(__name__)


 class _CheckpointConnector:
@@ -464,6 +460,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
            checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
         prec_plugin.on_save_checkpoint(checkpoint)

+        if _OMEGACONF_AVAILABLE:
+            from omegaconf import Container
+
         # dump hyper-parameters
         for obj in (model, datamodule):
             if obj and obj.hparams:

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 3 additions & 2 deletions

@@ -36,10 +36,11 @@
 else:
     from tqdm import tqdm

-_MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib")
-if TYPE_CHECKING and _MATPLOTLIB_AVAILABLE:
+if TYPE_CHECKING:
     import matplotlib.pyplot as plt
     from matplotlib.axes import Axes
+
+_MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib")
 log = logging.getLogger(__name__)
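The lr_finder.py hunk separates two concerns that the old `if TYPE_CHECKING and _MATPLOTLIB_AVAILABLE:` line conflated: a type checker evaluates the TYPE_CHECKING branch no matter what is installed in the runtime environment, so gating it on the availability flag only hid the matplotlib annotations from the checker. RequirementCache (from lightning_utilities) is truthy exactly when the named requirement is importable; a minimal sketch:

    from lightning_utilities.core.imports import RequirementCache

    _MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib")

    if _MATPLOTLIB_AVAILABLE:  # __bool__ checks whether the requirement is met
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()  # safe: guarded by the availability check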

src/lightning/pytorch/utilities/deepspeed.py

Lines changed: 8 additions & 7 deletions

@@ -23,13 +23,6 @@
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE

-if _DEEPSPEED_AVAILABLE:
-    from deepspeed.utils.zero_to_fp32 import (
-        get_fp32_state_dict_from_zero_checkpoint,
-        get_model_state_file,
-        get_optim_files,
-    )
-
 CPU_DEVICE = torch.device("cpu")
@@ -75,6 +68,14 @@ def convert_zero_checkpoint_to_fp32_state_dict(
             "lightning_model.pt"
         )
     """
+    if not _DEEPSPEED_AVAILABLE:
+        raise ModuleNotFoundError(str(_DEEPSPEED_AVAILABLE))
+
+    from deepspeed.utils.zero_to_fp32 import (
+        get_fp32_state_dict_from_zero_checkpoint,
+        get_model_state_file,
+        get_optim_files,
+    )

     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
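Since convert_zero_checkpoint_to_fp32_state_dict cannot do anything useful without deepspeed, this hunk fails fast instead of degrading gracefully: str() of an unmet RequirementCache is a human-readable description of the missing requirement, which becomes the ModuleNotFoundError message. A sketch of the pattern with a hypothetical convert helper:

    from lightning_utilities.core.imports import RequirementCache

    _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")

    def convert(checkpoint_dir: str) -> dict:
        # Fail fast, reusing the cache's own explanation of what is missing.
        if not _DEEPSPEED_AVAILABLE:
            raise ModuleNotFoundError(str(_DEEPSPEED_AVAILABLE))

        # Safe to import now: the guard guarantees deepspeed is installed.
        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

        return get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)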
