Skip to content

Commit 22813c4

Browse files
authored
Avoid torchelastic warning message when importing lightning (#15610)
1 parent f1bd68d commit 22813c4

File tree

6 files changed

+81
-6
lines changed

6 files changed

+81
-6
lines changed

src/lightning_lite/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5050

5151
- Fix an issue with the SLURM `srun` detection causing permission errors ([#15485](https://github.com/Lightning-AI/lightning/issues/15485))
5252

53-
-
53+
- Fixed the import of `lightning_lite` causing a warning 'Redirects are currently not supported in Windows or MacOs' ([#15610](https://github.com/PyTorchLightning/pytorch-lightning/issues/15610))
5454

5555
-

src/lightning_lite/strategies/deepspeed.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import platform
1919
from contextlib import contextmanager
2020
from pathlib import Path
21-
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union
21+
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
2222

2323
import torch
2424
from lightning_utilities.core.imports import RequirementCache
@@ -38,7 +38,7 @@
3838
from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
3939

4040
_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
41-
if _DEEPSPEED_AVAILABLE:
41+
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
4242
import deepspeed
4343

4444

@@ -271,6 +271,9 @@ def __init__(
271271
reduce_bucket_size=reduce_bucket_size,
272272
sub_group_size=sub_group_size,
273273
)
274+
275+
import deepspeed
276+
274277
self._config_initialized = False
275278
deepspeed.utils.logging.logger.setLevel(logging_level)
276279

@@ -327,6 +330,9 @@ def module_sharded_context(self) -> Generator[None, None, None]:
327330
# Current limitation in Lite: The config needs to be fully determined at the time of calling the
328331
# context manager, which happens at the start of `Lite.run()`. Later modificatoins through e.g. `Lite.setup()`
329332
# won't have an effect here.
333+
334+
import deepspeed
335+
330336
if self.zero_stage_3:
331337
assert self._config_initialized
332338

@@ -405,6 +411,8 @@ def _setup_module_and_optimizer(
405411
406412
This calls :func:`deepspeed.initialize` internally.
407413
"""
414+
import deepspeed
415+
408416
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
409417
deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
410418
args=argparse.Namespace(device_rank=self.root_device.index),
@@ -432,6 +440,8 @@ def _setup_distributed(self) -> None:
432440
self._config_initialized = True
433441

434442
def _init_deepspeed_distributed(self) -> None:
443+
import deepspeed
444+
435445
assert self.cluster_environment is not None
436446
if platform.system() != "Windows":
437447
# do not set env variables on windows, allow deepspeed to control setup
@@ -453,6 +463,8 @@ def _set_node_environment_variables(self) -> None:
453463
os.environ["LOCAL_RANK"] = str(self.local_rank)
454464

455465
def _set_deepspeed_activation_checkpointing(self) -> None:
466+
import deepspeed
467+
456468
assert isinstance(self.config, dict)
457469
if self.config.get("activation_checkpointing"):
458470
checkpoint_config = self.config["activation_checkpointing"]
@@ -573,6 +585,7 @@ def _restore_zero_state(self, module: Module, ckpt: Mapping[str, Any]) -> None:
573585
Args:
574586
ckpt: The ckpt file.
575587
"""
588+
import deepspeed
576589

577590
def load(module: torch.nn.Module, prefix: str = "") -> None:
578591

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7171

7272
- Fixed model state transfer in multiprocessing launcher when running multi-node ([#15567](https://github.com/Lightning-AI/lightning/pull/15567))
7373

74-
- Fixed manual optimization raising `AttributeError` with Bagua Strategy ([[#12534](https://github.com/PyTorchLightning/pytorch-lightning/issues/12534)])
74+
- Fixed manual optimization raising `AttributeError` with Bagua Strategy ([#12534](https://github.com/PyTorchLightning/pytorch-lightning/issues/12534))
75+
76+
- Fixed the import of `pytorch_lightning` causing a warning 'Redirects are currently not supported in Windows or MacOs' ([#15610](https://github.com/PyTorchLightning/pytorch-lightning/issues/15610))
7577

7678

7779
## [1.8.0] - 2022-11-01

src/pytorch_lightning/strategies/deepspeed.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import platform
2020
from collections import OrderedDict
2121
from pathlib import Path
22-
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
22+
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
2323

2424
import torch
2525
from lightning_utilities.core.apply_func import apply_to_collection
@@ -52,7 +52,7 @@
5252
warning_cache = WarningCache()
5353

5454
_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
55-
if _DEEPSPEED_AVAILABLE:
55+
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
5656
import deepspeed
5757

5858

@@ -319,6 +319,8 @@ def __init__(
319319
reduce_bucket_size=reduce_bucket_size,
320320
sub_group_size=sub_group_size,
321321
)
322+
import deepspeed
323+
322324
self._config_initialized = False
323325
deepspeed.utils.logging.logger.setLevel(logging_level)
324326

@@ -368,6 +370,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
368370
self.barrier()
369371

370372
def _init_deepspeed_distributed(self) -> None:
373+
import deepspeed
374+
371375
assert self.cluster_environment is not None
372376
if platform.system() != "Windows":
373377
# do not set env variables on windows, allow deepspeed to control setup
@@ -428,6 +432,8 @@ def _setup_model_and_optimizer(
428432
429433
This calls :func:`deepspeed.initialize` internally.
430434
"""
435+
import deepspeed
436+
431437
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
432438
deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
433439
args=argparse.Namespace(device_rank=self.root_device.index),
@@ -527,6 +533,8 @@ def _initialize_deepspeed_train(self, model: Module) -> None:
527533

528534
@contextlib.contextmanager
529535
def model_sharded_context(self) -> Generator[None, None, None]:
536+
import deepspeed
537+
530538
if self.zero_stage_3:
531539
assert self._config_initialized
532540

@@ -547,6 +555,8 @@ def model_sharded_context(self) -> Generator[None, None, None]:
547555
yield
548556

549557
def _set_deepspeed_activation_checkpointing(self) -> None:
558+
import deepspeed
559+
550560
assert isinstance(self.config, dict)
551561
if self.config.get("activation_checkpointing"):
552562
checkpoint_config = self.config["activation_checkpointing"]
@@ -559,6 +569,8 @@ def _set_deepspeed_activation_checkpointing(self) -> None:
559569
)
560570

561571
def _initialize_deepspeed_inference(self, model: Module) -> None:
572+
import deepspeed
573+
562574
assert isinstance(self.config, dict)
563575

564576
# todo: this is required for DeepSpeed throughput timers
@@ -639,6 +651,8 @@ def _format_batch_size_and_grad_accum_config(self) -> None:
639651
self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0
640652

641653
def _auto_select_batch_size(self) -> int:
654+
import deepspeed
655+
642656
# train_micro_batch_size_per_gpu is used for throughput logging purposes
643657
# by default we try to use the batch size of the loader
644658
assert self.lightning_module is not None
@@ -842,6 +856,7 @@ def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
842856
Args:
843857
ckpt: The ckpt file.
844858
"""
859+
import deepspeed
845860

846861
assert self.lightning_module is not None
847862

tests/tests_lite/utilities/test_imports.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import sys
1616
from textwrap import dedent
1717

18+
from tests_lite.helpers.runif import RunIf
19+
1820
from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
1921
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
2022

@@ -46,3 +48,24 @@ def test_import_lightning_lite_with_torch_dist_unavailable():
4648
)
4749
# run in complete isolation
4850
assert subprocess.call([sys.executable, "-c", code]) == 0
51+
52+
53+
@RunIf(deepspeed=True)
54+
def test_import_deepspeed_lazily():
55+
"""Test that we are importing deepspeed only when necessary."""
56+
code = dedent(
57+
"""
58+
import lightning_lite
59+
import sys
60+
61+
assert 'deepspeed' not in sys.modules
62+
from lightning_lite.strategies import DeepSpeedStrategy
63+
from lightning_lite.plugins import DeepSpeedPrecision
64+
assert 'deepspeed' not in sys.modules
65+
66+
import deepspeed
67+
assert 'deepspeed' in sys.modules
68+
"""
69+
)
70+
# run in complete isolation
71+
assert subprocess.call([sys.executable, "-c", code]) == 0

tests/tests_pytorch/utilities/test_imports.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE
2727
from pytorch_lightning.utilities import _APEX_AVAILABLE, _HOROVOD_AVAILABLE, _OMEGACONF_AVAILABLE, _POPTORCH_AVAILABLE
28+
from tests_pytorch.helpers.runif import RunIf
2829

2930

3031
def test_imports():
@@ -157,3 +158,24 @@ def test_import_pytorch_lightning_with_torch_dist_unavailable():
157158
)
158159
# run in complete isolation
159160
assert subprocess.call([sys.executable, "-c", code]) == 0
161+
162+
163+
@RunIf(deepspeed=True)
164+
def test_import_deepspeed_lazily():
165+
"""Test that we are importing deepspeed only when necessary."""
166+
code = dedent(
167+
"""
168+
import pytorch_lightning
169+
import sys
170+
171+
assert 'deepspeed' not in sys.modules
172+
from pytorch_lightning.strategies import DeepSpeedStrategy
173+
from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin
174+
assert 'deepspeed' not in sys.modules
175+
176+
import deepspeed
177+
assert 'deepspeed' in sys.modules
178+
"""
179+
)
180+
# run in complete isolation
181+
assert subprocess.call([sys.executable, "-c", code]) == 0

0 commit comments

Comments
 (0)