
Optimize import paths for optional dependencies #18561


Merged · 31 commits · Sep 15, 2023

Commits (31)
7cbd473
mock third party packages for autodoc typehints
awaelchli Sep 14, 2023
3ac49e6
mock
awaelchli Sep 14, 2023
e6a6d2c
loggers
awaelchli Sep 14, 2023
5cc17a3
wip
awaelchli Sep 14, 2023
e44bd8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2023
14210f9
revert
awaelchli Sep 14, 2023
ea98db7
debug
awaelchli Sep 14, 2023
d074423
Merge branch 'master' into docs/mock-third-party
awaelchli Sep 14, 2023
ea2f3d9
comet
awaelchli Sep 14, 2023
2c0b5f6
xla
awaelchli Sep 14, 2023
7903138
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2023
7ad26a5
undo
awaelchli Sep 14, 2023
93319c6
fabric deepspeed
awaelchli Sep 14, 2023
f42dd1b
strategy
awaelchli Sep 14, 2023
0417f32
update
awaelchli Sep 14, 2023
01e8664
handle xla
awaelchli Sep 14, 2023
6b2a9c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2023
011a77b
fix
awaelchli Sep 14, 2023
45bb7bc
xlafsdp
awaelchli Sep 14, 2023
182090b
mock for doctest
awaelchli Sep 14, 2023
29c2475
monkeypatch no longer required
awaelchli Sep 14, 2023
26cab23
xla
awaelchli Sep 14, 2023
4418b24
comet_ml is mocked, so we need to skip the doctest
awaelchli Sep 14, 2023
e6b31ed
config
awaelchli Sep 15, 2023
2fd050e
update
awaelchli Sep 15, 2023
25da60a
tf engine
awaelchli Sep 15, 2023
f749e88
fix
awaelchli Sep 15, 2023
a948e32
mock
awaelchli Sep 15, 2023
ef00d44
Update docs/source-fabric/conf.py
awaelchli Sep 15, 2023
838411f
try Jirka's suggestion
awaelchli Sep 15, 2023
a3ce7fb
Revert "try Jirka's suggestion"
awaelchli Sep 15, 2023
1 change: 0 additions & 1 deletion .github/workflows/docs-build.yml
@@ -79,7 +79,6 @@ jobs:
if: ${{ matrix.check == 'doctest' }}
working-directory: ./docs/source-${{ matrix.pkg-name }}
env:
SPHINX_MOCK_REQUIREMENTS: 0
FAST_DOCS_DEV: 1
run: |
make doctest
5 changes: 3 additions & 2 deletions docs/source-fabric/api/strategies.rst
@@ -8,8 +8,6 @@ lightning.fabric.strategies
Strategies
^^^^^^^^^^

.. TODO(fabric): include DeepSpeedStrategy, XLAStrategy

.. currentmodule:: lightning.fabric.strategies

.. autosummary::
@@ -21,6 +19,9 @@ Strategies
DDPStrategy
DataParallelStrategy
FSDPStrategy
DeepSpeedStrategy
XLAStrategy
XLAFSDPStrategy
ParallelStrategy
SingleDeviceStrategy
SingleDeviceXLAStrategy
9 changes: 8 additions & 1 deletion docs/source-fabric/conf.py
@@ -253,10 +253,12 @@
"torch": ("https://pytorch.org/docs/stable/", None),
"pytorch_lightning": ("https://lightning.ai/docs/pytorch/stable/", None),
"tensorboardX": ("https://tensorboardx.readthedocs.io/en/stable/", None),
"deepspeed": ("https://deepspeed.readthedocs.io/en/stable/", None),
"torch_xla": ("https://pytorch.org/xla/release/2.0/", None),
}
nitpicky = True

nitpick_ignore = [
nitpick_ignore_regex = [
("py:class", "typing.Self"),
# these are not generated with docs API ref
("py:class", "lightning.fabric.utilities.types.Optimizable"),
@@ -271,6 +273,10 @@
# These seem to be missing in reference generated API
("py:class", "torch.distributed.fsdp.wrap.ModuleWrapPolicy"),
("py:class", "torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler"),
# Mocked optional packages
("py:class", "deepspeed.*"),
("py:.*", "torch_xla.*"),
("py:class", "transformer_engine.*"),
]

# -- Options for todo extension ----------------------------------------------
@@ -308,6 +314,7 @@ def _package_list_from_file(file):
if _SPHINX_MOCK_REQUIREMENTS:
# mock also base packages when we are on RTD since we don't install them there
MOCK_PACKAGES += _package_list_from_file(os.path.join(_PATH_ROOT, "requirements.txt"))
MOCK_PACKAGES += ["deepspeed", "torch_xla", "transformer_engine"]
MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES]

autodoc_mock_imports = MOCK_PACKAGES
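For context on the conf.py hunks above: optional heavy packages are mocked for autodoc, and the cross-references that can no longer be resolved are silenced with regex patterns instead of one nitpick entry per symbol. A minimal, self-contained sketch of that Sphinx configuration pattern (a hypothetical standalone conf.py, not the actual Lightning docs config):

```python
# conf.py (sketch): mock optional dependencies so autodoc can import modules
# that reference them, even when the packages are not installed on the docs builder.
autodoc_mock_imports = ["deepspeed", "torch_xla", "transformer_engine"]

nitpicky = True
# Any reference into a mocked package cannot be resolved, so ignore those
# targets wholesale with regexes rather than listing every symbol.
nitpick_ignore_regex = [
    ("py:class", r"deepspeed\..*"),
    ("py:.*", r"torch_xla\..*"),
    ("py:class", r"transformer_engine\..*"),
]
```

With the packages mocked, Sphinx can import a module that does `import deepspeed` at module level without DeepSpeed being installed where the docs are built.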
3 changes: 3 additions & 0 deletions docs/source-pytorch/conf.py
@@ -378,6 +378,9 @@ def package_list_from_file(file):
MOCK_PACKAGES += package_list_from_file(_path_require("base.txt"))
MOCK_PACKAGES += package_list_from_file(_path_require("extra.txt"))
MOCK_PACKAGES += package_list_from_file(_path_require("strategies.txt"))
MOCK_PACKAGES += package_list_from_file(_path_require("loggers.info"))
MOCK_PACKAGES += ["comet_ml", "torch_xla", "transformer_engine"]
MOCK_PACKAGES.remove("jsonargparse")
MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES]

autodoc_mock_imports = MOCK_PACKAGES
6 changes: 2 additions & 4 deletions docs/source-pytorch/extensions/logging.rst
@@ -71,16 +71,14 @@ You can also pass a custom Logger to the :class:`~lightning.pytorch.trainer.trai

Choose from any of the others such as MLflow, Comet, Neptune, WandB, etc.

.. testcode::
:skipif: not _COMET_AVAILABLE
.. code-block:: python

comet_logger = pl_loggers.CometLogger(save_dir="logs/")
trainer = Trainer(logger=comet_logger)

To use multiple loggers, simply pass in a ``list`` or ``tuple`` of loggers.

.. testcode::
:skipif: (not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE) or not _COMET_AVAILABLE
.. code-block:: python

tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
comet_logger = pl_loggers.CometLogger(save_dir="logs/")
6 changes: 2 additions & 4 deletions docs/source-pytorch/visualize/supported_exp_managers.rst
@@ -8,8 +8,7 @@ To use `Comet.ml <https://www.comet.ml/site/>`_ first install the comet package:

Configure the logger and pass it to the :class:`~lightning.pytorch.trainer.trainer.Trainer`:

.. testcode::
:skipif: not _COMET_AVAILABLE
.. code-block:: python

from lightning.pytorch.loggers import CometLogger

@@ -40,8 +39,7 @@ To use `MLflow <https://mlflow.org/>`_ first install the MLflow package:

Configure the logger and pass it to the :class:`~lightning.pytorch.trainer.trainer.Trainer`:

.. testcode::
:skipif: not _MLFLOW_AVAILABLE
.. code-block:: python

from lightning.pytorch.loggers import MLFlowLogger

7 changes: 2 additions & 5 deletions src/lightning/fabric/plugins/precision/deepspeed.py
@@ -25,10 +25,7 @@
from lightning.fabric.utilities.types import Steppable

if TYPE_CHECKING:
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE

if _DEEPSPEED_AVAILABLE: # type: ignore[has-type]
import deepspeed
from deepspeed import DeepSpeedEngine

_PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true", "16-mixed", "bf16-mixed"]

@@ -88,7 +85,7 @@ def convert_input(self, data: Any) -> Any:
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

def backward(self, tensor: Tensor, model: "deepspeed.DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
def backward(self, tensor: Tensor, model: "DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
"""Performs back-propagation using DeepSpeed's engine."""
model.backward(tensor, *args, **kwargs)
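The precision-plugin hunks above replace a guarded runtime import with a TYPE_CHECKING-only import plus a plain forward reference in the annotation. A minimal sketch of that pattern (hypothetical module and function names, not Lightning source):

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Resolved by mypy/pyright and by the mocked Sphinx build, never at runtime.
    from deepspeed import DeepSpeedEngine


def backward(tensor: Any, model: "DeepSpeedEngine") -> None:
    # deepspeed code only runs when an actual DeepSpeedEngine is passed in,
    # which can only happen if deepspeed is installed.
    model.backward(tensor)
```

Because `TYPE_CHECKING` is `False` at runtime, importing such a module no longer triggers (or fails on) the optional import; only static analysis and the docs build resolve the name.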

@@ -24,14 +24,10 @@
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.utilities.rank_zero import rank_zero_warn

_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")

if TYPE_CHECKING and _TRANSFORMER_ENGINE_AVAILABLE:
if TYPE_CHECKING:
from transformer_engine.common.recipe import DelayedScaling
else:
DelayedScaling = None


_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")
log = logging.getLogger(__name__)
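In the transformer_engine precision hunk above, the availability check itself stays; only the TYPE_CHECKING block stops depending on it. A rough sketch of how such a `RequirementCache` guard is typically used at runtime (hypothetical function, simplified relative to the actual plugin):

```python
from lightning_utilities.core.imports import RequirementCache

_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")


def make_default_fp8_recipe() -> object:
    # The cache is evaluated lazily, so importing this module stays cheap.
    if not _TRANSFORMER_ENGINE_AVAILABLE:
        # str() of the cache carries a readable hint about the missing requirement.
        raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
    # Deferred import: only executed once the feature is actually requested.
    from transformer_engine.common.recipe import DelayedScaling

    return DelayedScaling()
```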


21 changes: 11 additions & 10 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -38,9 +38,10 @@
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH

if TYPE_CHECKING:
from deepspeed import DeepSpeedEngine

_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
import deepspeed


# TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
@@ -289,7 +290,7 @@ def __init__(
self.hysteresis = hysteresis
self.min_loss_scale = min_loss_scale

self._deepspeed_engine: Optional["deepspeed.DeepSpeedEngine"] = None
self._deepspeed_engine: Optional["DeepSpeedEngine"] = None

@property
def zero_stage_3(self) -> bool:
@@ -302,12 +303,12 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]:
return {"num_replicas": self.world_size, "rank": self.global_rank}

@property
def model(self) -> "deepspeed.DeepSpeedEngine":
def model(self) -> "DeepSpeedEngine":
return self._deepspeed_engine

def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer]
) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]:
) -> Tuple["DeepSpeedEngine", List[Optimizer]]:
"""Set up a model and multiple optimizers together.

Currently, only a single optimizer is supported.
@@ -327,7 +328,7 @@ def setup_module_and_optimizers(
self._set_deepspeed_activation_checkpointing()
return self._deepspeed_engine, [optimizer]

def setup_module(self, module: Module) -> "deepspeed.DeepSpeedEngine":
def setup_module(self, module: Module) -> "DeepSpeedEngine":
"""Set up a module for inference (no optimizers).

For training, see :meth:`setup_module_and_optimizers`.
@@ -515,7 +516,7 @@ def load_checkpoint(

def clip_gradients_norm(
self,
module: "deepspeed.DeepSpeedEngine",
module: "DeepSpeedEngine",
optimizer: Optimizer,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
@@ -527,7 +528,7 @@ def clip_gradients_value(
)

def clip_gradients_value(
self, module: "deepspeed.DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int]
self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int]
) -> None:
raise NotImplementedError(
"DeepSpeed handles gradient clipping automatically within the optimizer. "
@@ -571,7 +572,7 @@ def _initialize_engine(
self,
model: Module,
optimizer: Optional[Optimizer] = None,
) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
) -> Tuple["DeepSpeedEngine", Optimizer]:
"""Initialize one model and one optimizer with an optional learning rate scheduler.

This calls :func:`deepspeed.initialize` internally.
@@ -790,7 +791,7 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
return config


def _get_deepspeed_engines_from_state(state: Dict[str, Any]) -> List["deepspeed.DeepSpeedEngine"]:
def _get_deepspeed_engines_from_state(state: Dict[str, Any]) -> List["DeepSpeedEngine"]:
from deepspeed import DeepSpeedEngine

modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module)))
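`_get_deepspeed_engines_from_state` above shows the runtime counterpart to the TYPE_CHECKING import: when a real class is needed (here for isinstance checks), it is imported inside the function that needs it. A small sketch of the same idea, with hypothetical names:

```python
from typing import Any, List


def collect_engines(objects: List[Any]) -> List[Any]:
    # Deferred import: only runs when the function is called, so the
    # module-level import path stays free of the optional dependency.
    from deepspeed import DeepSpeedEngine

    return [obj for obj in objects if isinstance(obj, DeepSpeedEngine)]
```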
4 changes: 2 additions & 2 deletions src/lightning/fabric/strategies/xla.py
@@ -21,7 +21,7 @@
from torch.utils.data import DataLoader

from lightning.fabric.accelerators import Accelerator
from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1
from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_GREATER_EQUAL_2_1
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins.io.xla import XLACheckpointIO
@@ -32,7 +32,7 @@
from lightning.fabric.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.types import _PATH, ReduceOp

if TYPE_CHECKING and _XLA_AVAILABLE:
if TYPE_CHECKING:
from torch_xla.distributed.parallel_loader import MpDeviceLoader


4 changes: 2 additions & 2 deletions src/lightning/fabric/strategies/xla_fsdp.py
@@ -24,7 +24,7 @@
from torch.utils.data import DataLoader

from lightning.fabric.accelerators import Accelerator
from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_AVAILABLE
from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
@@ -39,7 +39,7 @@
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp

if TYPE_CHECKING and _XLA_AVAILABLE:
if TYPE_CHECKING:
from torch_xla.distributed.parallel_loader import MpDeviceLoader

_POLICY_SET = Set[Type[Module]]
7 changes: 5 additions & 2 deletions src/lightning/pytorch/loggers/comet.py
@@ -19,7 +19,7 @@
import logging
import os
from argparse import Namespace
from typing import Any, Dict, Mapping, Optional, Union
from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union

from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
@@ -30,6 +30,9 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_only

if TYPE_CHECKING:
from comet_ml import ExistingExperiment, Experiment, OfflineExperiment

log = logging.getLogger(__name__)
_COMET_AVAILABLE = RequirementCache("comet-ml>=3.31.0", module="comet_ml")

@@ -255,7 +258,7 @@ def __init__(

@property
@rank_zero_experiment
def experiment(self) -> Any:
def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperiment"]:
r"""
Actual Comet object. To use Comet features in your
:class:`~lightning.pytorch.core.module.LightningModule` do the following.
7 changes: 5 additions & 2 deletions src/lightning/pytorch/loggers/mlflow.py
@@ -22,7 +22,7 @@
from argparse import Namespace
from pathlib import Path
from time import time
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, TYPE_CHECKING, Union

import yaml
from lightning_utilities.core.imports import RequirementCache
@@ -34,6 +34,9 @@
from lightning.pytorch.loggers.utilities import _scan_checkpoints
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn

if TYPE_CHECKING:
from mlflow.tracking import MlflowClient

log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")
@@ -142,7 +145,7 @@ def __init__(

@property
@rank_zero_experiment
def experiment(self) -> Any:
def experiment(self) -> "MlflowClient":
r"""
Actual MLflow object. To use MLflow features in your
:class:`~lightning.pytorch.core.module.LightningModule` do the following.
13 changes: 9 additions & 4 deletions src/lightning/pytorch/loggers/wandb.py
@@ -18,7 +18,7 @@
import os
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
from typing import Any, Dict, List, Literal, Mapping, Optional, TYPE_CHECKING, Union

import torch.nn as nn
from lightning_utilities.core.imports import RequirementCache
@@ -32,6 +32,11 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn

if TYPE_CHECKING:
from wandb import Artifact
from wandb.sdk.lib import RunDisabled
from wandb.wandb_run import Run

_WANDB_AVAILABLE = RequirementCache("wandb>=0.12.10")


@@ -288,7 +293,7 @@ def __init__(
anonymous: Optional[bool] = None,
project: Optional[str] = None,
log_model: Union[Literal["all"], bool] = False,
experiment: Optional[Any] = None,
experiment: Union["Run", "RunDisabled", None] = None,
prefix: str = "",
checkpoint_name: Optional[str] = None,
**kwargs: Any,
@@ -359,7 +364,7 @@ def __getstate__(self) -> Dict[str, Any]:

@property
@rank_zero_experiment
def experiment(self) -> Any:
def experiment(self) -> Union["Run", "RunDisabled"]:
r"""

Actual wandb object. To use wandb features in your
@@ -549,7 +554,7 @@ def download_artifact(
save_dir = None if save_dir is None else os.fspath(save_dir)
return artifact.download(root=save_dir)

def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> Any:
def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "Artifact":
"""Logs to the wandb dashboard that the mentioned artifact is used by the run.

Args:
3 changes: 1 addition & 2 deletions src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -24,15 +24,14 @@
import lightning.pytorch as pl
from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning.fabric.utilities.types import Steppable
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import WarningCache

if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
if TYPE_CHECKING:
import deepspeed

warning_cache = WarningCache()
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/deepspeed.py
@@ -51,7 +51,7 @@
log = logging.getLogger(__name__)
warning_cache = WarningCache()

if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
if TYPE_CHECKING:
import deepspeed

