diff --git a/src/lightning_lite/plugins/__init__.py b/src/lightning_lite/plugins/__init__.py
index 54aa3a4e4e113..85b863ea02b66 100644
--- a/src/lightning_lite/plugins/__init__.py
+++ b/src/lightning_lite/plugins/__init__.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
 from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO
@@ -28,10 +27,10 @@
     "CheckpointIO",
     "TorchCheckpointIO",
     "XLACheckpointIO",
+    "Precision",
     "DeepSpeedPrecision",
     "DoublePrecision",
     "NativeMixedPrecision",
-    "Precision",
     "TPUPrecision",
     "TPUBf16Precision",
 ]
diff --git a/src/lightning_lite/plugins/collectives/__init__.py b/src/lightning_lite/plugins/collectives/__init__.py
new file mode 100644
index 0000000000000..90a47c8c4159c
--- /dev/null
+++ b/src/lightning_lite/plugins/collectives/__init__.py
@@ -0,0 +1,9 @@
+from lightning_lite.plugins.collectives.collective import Collective
+from lightning_lite.plugins.collectives.single_device_collective import SingleDeviceCollective
+from lightning_lite.plugins.collectives.torch_collective import TorchCollective
+
+__all__ = [
+    "Collective",
+    "TorchCollective",
+    "SingleDeviceCollective",
+]
diff --git a/src/lightning_lite/plugins/collectives/collective.py b/src/lightning_lite/plugins/collectives/collective.py
new file mode 100644
index 0000000000000..f2e7f896b3547
--- /dev/null
+++ b/src/lightning_lite/plugins/collectives/collective.py
@@ -0,0 +1,137 @@
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional
+
+import torch
+from typing_extensions import Self
+
+from lightning_lite.utilities.types import CollectibleGroup
+
+
+class Collective(ABC):
+    """Interface for collective operations.
+
+    Supports communications between multiple processes and multiple nodes. A collective owns a group.
+
+    .. warning::
+        This API is experimental and subject to change
+    """
+
+    def __init__(self) -> None:
+        self._group: Optional[CollectibleGroup] = None
+
+    @property
+    @abstractmethod
+    def rank(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def world_size(self) -> int:
+        ...
+
+    @property
+    def group(self) -> CollectibleGroup:
+        if self._group is None:
+            raise RuntimeError(
+                f"`{type(self).__name__}` does not own a group. HINT: try `collective.create_group().group`"
+            )
+        return self._group
+
+    @abstractmethod
+    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
+        ...
+
+    @abstractmethod
+    def all_reduce(self, tensor: torch.Tensor, op: str) -> torch.Tensor:
+        ...
+
+    @abstractmethod
+    def reduce(self, tensor: torch.Tensor, dst: int, op: str) -> torch.Tensor:
+        ...
+
+    @abstractmethod
+    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
+        ...
+
+    @abstractmethod
+    def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]:
+        ...
+
+    @abstractmethod
+    def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor:
+        ...
+
+    @abstractmethod
+    def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], op: str) -> torch.Tensor:
+        ...
+
+    @abstractmethod
+    def all_to_all(
+        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
+    ) -> List[torch.Tensor]:
+        ...
+
+    @abstractmethod
+    def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
+        ...
+
+    @abstractmethod
+    def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
+        ...
+
+    @abstractmethod
+    def barrier(self, device_ids: Optional[List[int]] = None) -> None:
+        ...
+
+    @classmethod
+    @abstractmethod
+    def is_available(cls) -> bool:
+        ...
+
+    @classmethod
+    @abstractmethod
+    def is_initialized(cls) -> bool:
+        ...
+
+    @classmethod
+    @abstractmethod
+    def init_group(cls, **kwargs: Any) -> None:
+        ...
+
+    @classmethod
+    @abstractmethod
+    def new_group(cls, **kwargs: Any) -> CollectibleGroup:
+        ...
+
+    @classmethod
+    @abstractmethod
+    def destroy_group(cls, group: CollectibleGroup) -> None:
+        ...
+
+    @classmethod
+    @abstractmethod
+    def _convert_to_native_op(cls, op: str) -> Any:
+        ...
+
+    def setup(self, **kwargs: Any) -> Self:  # type: ignore[valid-type]
+        if not self.is_initialized():
+            self.init_group(**kwargs)
+        return self
+
+    def create_group(self, **kwargs: Any) -> Self:  # type: ignore[valid-type]
+        """Create a group.
+
+        This assumes that :meth:`~lightning_lite.plugins.collectives.Collective.init_group` has been
+        called already by the user.
+        """
+        if self._group is not None:
+            raise RuntimeError(f"`{type(self).__name__}` already owns a group.")
+        self._group = self.new_group(**kwargs)
+        return self
+
+    def teardown(self) -> Self:  # type: ignore[valid-type]
+        if self._group is None:
+            raise RuntimeError(f"`{type(self).__name__}` does not own a group to destroy.")
+        self.destroy_group(self._group)
+        self._group = None
+        return self
diff --git a/src/lightning_lite/plugins/collectives/single_device_collective.py b/src/lightning_lite/plugins/collectives/single_device_collective.py
new file mode 100644
index 0000000000000..1bc524192cf34
--- /dev/null
+++ b/src/lightning_lite/plugins/collectives/single_device_collective.py
@@ -0,0 +1,81 @@
+from typing import Any, List
+
+import torch
+
+from lightning_lite.plugins.collectives.collective import Collective
+from lightning_lite.utilities.types import CollectibleGroup
+
+
+class SingleDeviceCollective(Collective):
+    @property
+    def rank(self) -> int:
+        return 0
+
+    @property
+    def world_size(self) -> int:
+        return 1
+
+    def broadcast(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+        return tensor
+
+    def all_reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+        return tensor
+
+    def reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+        return tensor
+
+    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor, **__: Any) -> List[torch.Tensor]:
+        return [tensor]
+
+    def gather(self, tensor: torch.Tensor, *_: Any, **__: Any) -> List[torch.Tensor]:
+        return [tensor]
+
+    def scatter(
+        self,
+        tensor: torch.Tensor,
+        scatter_list: List[torch.Tensor],
+        *_: Any,
+        **__: Any,
+    ) -> torch.Tensor:
+        return scatter_list[0]
+
+    def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
+        return input_list[0]
+
+    def all_to_all(
+        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor], *_: Any, **__: Any
+    ) -> List[torch.Tensor]:
+        return input_tensor_list
+
+    def send(self, *_: Any, **__: Any) -> None:
+        pass
+
+    def recv(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
+        return tensor
+
+    def barrier(self, *_: Any, **__: Any) -> None:
+        pass
+
+    @classmethod
+    def is_available(cls) -> bool:
+        return True  # vacuous truth
+
+    @classmethod
+    def is_initialized(cls) -> bool:
+        return True  # vacuous truth
+
+    @classmethod
+    def init_group(cls, **_: Any) -> None:
+        pass
+
+    @classmethod
+    def new_group(cls, **_: Any) -> CollectibleGroup:
+        return object()  # type: ignore[return-value]
+
+    @classmethod
+    def destroy_group(cls, group: CollectibleGroup) -> None:
+        pass
+
+    @classmethod
+    def _convert_to_native_op(cls, op: str) -> str:
+        return op
diff --git a/src/lightning_lite/plugins/collectives/torch_collective.py b/src/lightning_lite/plugins/collectives/torch_collective.py
new file mode 100644
index 0000000000000..6361657480a5b
--- /dev/null
+++ b/src/lightning_lite/plugins/collectives/torch_collective.py
@@ -0,0 +1,167 @@
+import datetime
+import os
+from typing import Any, List, Optional, Union
+
+import torch
+import torch.distributed as dist
+from typing_extensions import Self
+
+from lightning_lite.plugins.collectives.collective import Collective
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
+from lightning_lite.utilities.types import CollectibleGroup, ReduceOp
+
+if dist.is_available():
+    from torch.distributed.constants import default_pg_timeout
+else:
+    default_pg_timeout = datetime.timedelta(seconds=1800)
+
+
+class TorchCollective(Collective):
+    def __init__(self) -> None:
+        if not dist.is_available():
+            raise RuntimeError("Torch distributed is not available.")
+        super().__init__()
+
+    @property
+    def rank(self) -> int:
+        return dist.get_rank(self.group)
+
+    @property
+    def world_size(self) -> int:
+        return dist.get_world_size(self.group)
+
+    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
+        dist.broadcast(tensor, src, group=self.group)
+        return tensor
+
+    def all_reduce(self, tensor: torch.Tensor, op: Union[str, ReduceOp] = "sum") -> torch.Tensor:
+        op = self._convert_to_native_op(op)
+        dist.all_reduce(tensor, op=op, group=self.group)
+        return tensor
+
+    def reduce(self, tensor: torch.Tensor, dst: int, op: Union[str, ReduceOp] = "sum") -> torch.Tensor:
+        op = self._convert_to_native_op(op)
+        dist.reduce(tensor, dst, op=op, group=self.group)
+        return tensor
+
+    def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
+        dist.all_gather(tensor_list, tensor, group=self.group)
+        return tensor_list
+
+    def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]:
+        dist.gather(tensor, gather_list, dst, group=self.group)
+        return gather_list
+
+    def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor:
+        dist.scatter(tensor, scatter_list, src, group=self.group)
+        return tensor
+
+    def reduce_scatter(
+        self, output: torch.Tensor, input_list: List[torch.Tensor], op: Union[str, ReduceOp] = "sum"
+    ) -> torch.Tensor:
+        op = self._convert_to_native_op(op)
+        dist.reduce_scatter(output, input_list, op=op, group=self.group)
+        return output
+
+    def all_to_all(
+        self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
+    ) -> List[torch.Tensor]:
+        dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
+        return output_tensor_list
+
+    def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
+        dist.send(tensor, dst, tag=tag, group=self.group)
+
+    def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
+        dist.recv(tensor, src, tag=tag, group=self.group)
+        return tensor
+
+    def all_gather_object(self, object_list: List[Any], obj: Any) -> List[Any]:
+        dist.all_gather_object(object_list, obj, group=self.group)
+        return object_list
+
+    def broadcast_object_list(
+        self, object_list: List[Any], src: int, device: Optional[torch.device] = None
+    ) -> List[Any]:
+        kwargs = {}
+        if _TORCH_GREATER_EQUAL_1_10:
+            kwargs["device"] = device
+        dist.broadcast_object_list(object_list, src, group=self.group, **kwargs)
+        return object_list
+
+    def gather_object(self, obj: Any, object_gather_list: List[Any], dst: int = 0) -> List[Any]:
+        dist.gather_object(obj, object_gather_list, dst, group=self.group)
+        return object_gather_list
+
+    def scatter_object_list(
+        self, scatter_object_output_list: List[Any], scatter_object_input_list: List[Any], src: int = 0
+    ) -> List[Any]:
+        dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)
+        return scatter_object_output_list
+
+    def barrier(self, device_ids: Optional[List[int]] = None) -> None:
+        dist.barrier(group=self.group, device_ids=device_ids)
+
+    def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None:
+        dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)
+
+    def setup(
+        self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any
+    ) -> Self:  # type: ignore[valid-type]
+        if self.is_initialized():
+            return self
+        # maybe set addr
+        set_addr = False
+        addr_key = "MASTER_ADDR"
+        if main_address is not None and addr_key not in os.environ:
+            os.environ[addr_key] = main_address
+            set_addr = True
+        # maybe set port
+        set_port = False
+        port_key = "MASTER_PORT"
+        if main_port is not None and port_key not in os.environ:
+            os.environ[port_key] = str(main_port)
+            set_port = True
+        # this will `init_group`
+        super().setup(**kwargs)
+        # cleanup
+        if set_addr:
+            os.environ.pop("MASTER_ADDR", None)
+        if set_port:
+            os.environ.pop("MASTER_PORT", None)
+        return self
+
+    def teardown(self) -> Self:  # type: ignore[valid-type]
+        return super().teardown()
+
+    @classmethod
+    def is_available(cls) -> bool:
+        return dist.is_available()
+
+    @classmethod
+    def is_initialized(cls) -> bool:
+        return dist.is_initialized()
+
+    @classmethod
+    def init_group(cls, **kwargs: Any) -> None:
+        dist.init_process_group(**kwargs)
+
+    @classmethod
+    def new_group(cls, **kwargs: Any) -> CollectibleGroup:
+        return dist.new_group(**kwargs)
+
+    @classmethod
+    def destroy_group(cls, group: CollectibleGroup) -> None:
+        dist.destroy_process_group(group)
+
+    @classmethod
+    def _convert_to_native_op(cls, op: Union[str, ReduceOp]) -> ReduceOp:
+        if isinstance(op, ReduceOp):
+            return op
+        if not isinstance(op, str):
+            raise ValueError(f"op {op!r} should be a `str` or `ReduceOp`")
+        op = op.upper()
+        value = getattr(ReduceOp, op, None)
+        if value is None:
+            raise ValueError(f"op {op!r} is not a member of `ReduceOp`")
+        return value
diff --git a/src/lightning_lite/strategies/ddp.py b/src/lightning_lite/strategies/ddp.py
index bd229be91934b..1a687a861d663 100644
--- a/src/lightning_lite/strategies/ddp.py
+++ b/src/lightning_lite/strategies/ddp.py
@@ -17,11 +17,11 @@
 import torch
 import torch.distributed
 from torch import Tensor
-from torch.distributed.constants import default_pg_timeout
 from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 from lightning_lite.accelerators.accelerator import Accelerator
+from lightning_lite.plugins.collectives.torch_collective import default_pg_timeout
 from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
 from lightning_lite.plugins.precision import Precision
diff --git a/src/lightning_lite/strategies/ddp_spawn.py b/src/lightning_lite/strategies/ddp_spawn.py
index def19d4ac0f24..5e481d4db83e5 100644
--- a/src/lightning_lite/strategies/ddp_spawn.py
+++ b/src/lightning_lite/strategies/ddp_spawn.py
@@ -17,12 +17,12 @@
 import torch
 import torch.distributed
 from torch import Tensor
-from torch.distributed.constants import default_pg_timeout
 from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
 from typing_extensions import Literal
 
 from lightning_lite.accelerators.accelerator import Accelerator
+from lightning_lite.plugins.collectives.torch_collective import default_pg_timeout
 from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
 from lightning_lite.plugins.precision import Precision
diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py
index 86da43ec41ec5..f320df4776224 100644
--- a/src/lightning_lite/strategies/fairscale.py
+++ b/src/lightning_lite/strategies/fairscale.py
@@ -17,12 +17,14 @@
 
 import torch
 from lightning_utilities.core.imports import module_available
-from torch.distributed.constants import default_pg_timeout
 from torch.nn import Module
 from torch.optim import Optimizer
 
 from lightning_lite.accelerators import Accelerator
-from lightning_lite.plugins import CheckpointIO, ClusterEnvironment, Precision
+from lightning_lite.plugins.collectives.torch_collective import default_pg_timeout
+from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
+from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
+from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.strategies import DDPSpawnStrategy
 from lightning_lite.strategies.ddp import DDPStrategy
 from lightning_lite.utilities.enums import PrecisionType
diff --git a/src/lightning_lite/strategies/parallel.py b/src/lightning_lite/strategies/parallel.py
index 2036c7943049b..40acc18f41342 100644
--- a/src/lightning_lite/strategies/parallel.py
+++ b/src/lightning_lite/strategies/parallel.py
@@ -25,7 +25,8 @@
 from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
 from lightning_lite.plugins.precision import Precision
 from lightning_lite.strategies.strategy import Strategy
-from lightning_lite.utilities.distributed import all_gather_ddp_if_available, ReduceOp
+from lightning_lite.utilities.distributed import all_gather_ddp_if_available
+from lightning_lite.utilities.types import ReduceOp
 
 
 class ParallelStrategy(Strategy, ABC):
diff --git a/src/lightning_lite/strategies/strategy.py b/src/lightning_lite/strategies/strategy.py
index adb9f0fb728e3..ae6db20c067de 100644
--- a/src/lightning_lite/strategies/strategy.py
+++ b/src/lightning_lite/strategies/strategy.py
@@ -28,9 +28,8 @@
 from lightning_lite.plugins.precision import Precision
 from lightning_lite.strategies.launchers.base import _Launcher
 from lightning_lite.utilities.apply_func import move_data_to_device
-from lightning_lite.utilities.distributed import ReduceOp
 from lightning_lite.utilities.optimizer import optimizer_to_device
-from lightning_lite.utilities.types import _PATH, Optimizable
+from lightning_lite.utilities.types import _PATH, Optimizable, ReduceOp
 
 TBroadcast = TypeVar("TBroadcast")
 TReduce = TypeVar("TReduce")
diff --git a/src/lightning_lite/strategies/xla.py b/src/lightning_lite/strategies/xla.py
index 80165777814ac..c2ad15865c058 100644
--- a/src/lightning_lite/strategies/xla.py
+++ b/src/lightning_lite/strategies/xla.py
@@ -31,9 +31,8 @@
 from lightning_lite.strategies.strategy import TBroadcast
 from lightning_lite.utilities.apply_func import apply_to_collection
 from lightning_lite.utilities.data import has_len
-from lightning_lite.utilities.distributed import ReduceOp
 from lightning_lite.utilities.rank_zero import rank_zero_only
-from lightning_lite.utilities.types import _PATH
+from lightning_lite.utilities.types import _PATH, ReduceOp
 
 if TYPE_CHECKING and _XLA_AVAILABLE:
     from torch_xla.distributed.parallel_loader import MpDeviceLoader
diff --git a/src/lightning_lite/utilities/__init__.py b/src/lightning_lite/utilities/__init__.py
index 4237b5c23a405..63e90494928da 100644
--- a/src/lightning_lite/utilities/__init__.py
+++ b/src/lightning_lite/utilities/__init__.py
@@ -14,7 +14,6 @@
 """General utilities."""
 
 from lightning_lite.utilities.apply_func import move_data_to_device  # noqa: F401
-from lightning_lite.utilities.distributed import AllGatherGrad  # noqa: F401
 from lightning_lite.utilities.enums import _AcceleratorType, _StrategyType, AMPType, LightningEnum  # noqa: F401
 
 # TODO(lite): Avoid importing protected attributes in `__init__.py` files
diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py
index ce1f27e82d05d..9e7ea0142ab0b 100644
--- a/src/lightning_lite/utilities/distributed.py
+++ b/src/lightning_lite/utilities/distributed.py
@@ -10,14 +10,12 @@
 from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning_lite.utilities.imports import _HPU_AVAILABLE
 from lightning_lite.utilities.rank_zero import rank_zero_info
+from lightning_lite.utilities.types import ReduceOp
 
 if torch.distributed.is_available():
-    from torch.distributed import group, ReduceOp
+    from torch.distributed import group
 else:
 
-    class ReduceOp:  # type: ignore # (see https://github.com/python/mypy/issues/1153)
-        SUM = None
-
     class group:  # type: ignore
         WORLD = None
 
diff --git a/src/lightning_lite/utilities/types.py b/src/lightning_lite/utilities/types.py
index ed04dd9db2950..c8037616cb4af 100644
--- a/src/lightning_lite/utilities/types.py
+++ b/src/lightning_lite/utilities/types.py
@@ -25,6 +25,13 @@
 _PARAMETERS = Iterator[torch.nn.Parameter]
 
 
+if torch.distributed.is_available():
+    from torch.distributed import ProcessGroup, ReduceOp
+else:
+    ProcessGroup = Any  # type: ignore[assignment,misc]
+    ReduceOp = object  # type: ignore[assignment,misc]  # we are using isinstance check once
+
+
 _DictKey = TypeVar("_DictKey")
 
 
@@ -39,6 +46,15 @@ def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
         ...
 
 
+@runtime_checkable
+class CollectibleGroup(Protocol):
+    def size(self) -> int:
+        ...
+
+    def rank(self) -> int:
+        ...
+
+
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
diff --git a/src/pytorch_lightning/strategies/bagua.py b/src/pytorch_lightning/strategies/bagua.py
index 3099abb0b9a8a..686a0e15d017f 100644
--- a/src/pytorch_lightning/strategies/bagua.py
+++ b/src/pytorch_lightning/strategies/bagua.py
@@ -9,9 +9,9 @@
 
 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
-from lightning_lite.utilities.distributed import ReduceOp
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
+from lightning_lite.utilities.types import ReduceOp
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py
index 5f28908341a23..d68c1f3da9afa 100644
--- a/src/pytorch_lightning/strategies/ddp.py
+++ b/src/pytorch_lightning/strategies/ddp.py
@@ -30,12 +30,14 @@
 
 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
+from lightning_lite.plugins.collectives.torch_collective import default_pg_timeout
 from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device
 from lightning_lite.utilities.distributed import group as _group
-from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
+from lightning_lite.utilities.distributed import init_dist_connection, sync_ddp_if_available
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
+from lightning_lite.utilities.types import ReduceOp
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
@@ -58,11 +60,6 @@
 if _TORCH_GREATER_EQUAL_1_10 and torch.distributed.is_available():
     from torch.distributed.algorithms.model_averaging.averagers import ModelAverager
 
-if torch.distributed.is_available():
-    from torch.distributed.constants import default_pg_timeout
-else:
-    default_pg_timeout = timedelta(seconds=1800)
-
 log = logging.getLogger(__name__)
 
 
diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py
index 793c8155d01cb..345928bdc8b6c 100644
--- a/src/pytorch_lightning/strategies/ddp_spawn.py
+++ b/src/pytorch_lightning/strategies/ddp_spawn.py
@@ -25,10 +25,12 @@
 
 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
+from lightning_lite.plugins.collectives.torch_collective import default_pg_timeout
 from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device
 from lightning_lite.utilities.distributed import group as _group
-from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
+from lightning_lite.utilities.distributed import init_dist_connection, sync_ddp_if_available
 from lightning_lite.utilities.optimizer import optimizers_to_device
+from lightning_lite.utilities.types import ReduceOp
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
@@ -42,11 +44,6 @@
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
 
-if torch.distributed.is_available():
-    from torch.distributed.constants import default_pg_timeout
-else:
-    default_pg_timeout = timedelta(seconds=1800)
-
 log = logging.getLogger(__name__)
 
 _DDP_FORK_ALIASES = (
diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py
index ff982b876dde0..3c095dcc01403 100644
--- a/src/pytorch_lightning/strategies/fully_sharded_native.py
+++ b/src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -22,9 +22,10 @@
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
 from lightning_lite.utilities.distributed import get_default_process_group_backend_for_device
 from lightning_lite.utilities.distributed import group as _group
-from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
+from lightning_lite.utilities.distributed import init_dist_connection, sync_ddp_if_available
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
+from lightning_lite.utilities.types import ProcessGroup, ReduceOp
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
@@ -36,7 +37,7 @@
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
-from pytorch_lightning.utilities.types import ProcessGroup, STEP_OUTPUT
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 _distributed_available = torch.distributed.is_available()
 _fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available
diff --git a/src/pytorch_lightning/strategies/horovod.py b/src/pytorch_lightning/strategies/horovod.py
index 4009e09655e11..ea2bcd029c809 100644
--- a/src/pytorch_lightning/strategies/horovod.py
+++ b/src/pytorch_lightning/strategies/horovod.py
@@ -23,7 +23,7 @@
 from lightning_lite.plugins import CheckpointIO
 from lightning_lite.utilities.distributed import distributed_available
 from lightning_lite.utilities.distributed import group as dist_group
-from lightning_lite.utilities.distributed import ReduceOp
+from lightning_lite.utilities.types import ReduceOp
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py
index 76c505a0a3ed2..88757f804f09e 100644
--- a/src/pytorch_lightning/strategies/tpu_spawn.py
+++ b/src/pytorch_lightning/strategies/tpu_spawn.py
@@ -26,9 +26,8 @@
 from lightning_lite.plugins import CheckpointIO, XLACheckpointIO
 from lightning_lite.plugins.environments import XLAEnvironment
 from lightning_lite.utilities.data import has_len
-from lightning_lite.utilities.distributed import ReduceOp
 from lightning_lite.utilities.optimizer import optimizers_to_device
-from lightning_lite.utilities.types import _PATH
+from lightning_lite.utilities.types import _PATH, ReduceOp
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
diff --git a/src/pytorch_lightning/utilities/__init__.py b/src/pytorch_lightning/utilities/__init__.py
index dc5c81f2a8919..62aeee50fd643 100644
--- a/src/pytorch_lightning/utilities/__init__.py
+++ b/src/pytorch_lightning/utilities/__init__.py
@@ -16,7 +16,8 @@
 import numpy
 
 from lightning_lite.utilities import move_data_to_device  # noqa: F401
-from lightning_lite.utilities import AllGatherGrad, AMPType, LightningEnum  # noqa: F401
+from lightning_lite.utilities import AMPType, LightningEnum  # noqa: F401
+from lightning_lite.utilities.distributed import AllGatherGrad  # noqa: F401
 from pytorch_lightning.utilities.enums import GradClipAlgorithmType  # noqa: F401
 from pytorch_lightning.utilities.grads import grad_norm  # noqa: F401
 from pytorch_lightning.utilities.imports import (  # noqa: F401
diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py
index cc9068bd81cc7..f7a8942f503bd 100644
--- a/src/pytorch_lightning/utilities/types.py
+++ b/src/pytorch_lightning/utilities/types.py
@@ -27,12 +27,7 @@
 from torchmetrics import Metric
 from typing_extensions import Protocol, runtime_checkable
 
-from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau
-
-if torch.distributed.is_available():
-    from torch._C._distributed_c10d import ProcessGroup
-else:
-    ProcessGroup = Any  # type: ignore[assignment,misc]
+from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau
 
 _NUMBER = Union[int, float]
 _METRIC = Union[Metric, Tensor, _NUMBER]
diff --git a/tests/tests_lite/plugins/collectives/__init__.py b/tests/tests_lite/plugins/collectives/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tests_lite/plugins/collectives/test_single_device_collective.py b/tests/tests_lite/plugins/collectives/test_single_device_collective.py
new file mode 100644
index 0000000000000..f9c42eb74322a
--- /dev/null
+++ b/tests/tests_lite/plugins/collectives/test_single_device_collective.py
@@ -0,0 +1,29 @@
+from unittest import mock
+
+import pytest
+
+from lightning_lite.plugins.collectives import SingleDeviceCollective
+
+
+def test_can_instantiate_without_args():
+    SingleDeviceCollective()
+
+
+def test_create_group():
+    collective = SingleDeviceCollective()
+    assert collective.is_available()
+    assert collective.is_initialized()
+
+    with pytest.raises(RuntimeError, match=r"SingleDeviceCollective` does not own a group"):
+        _ = collective.group
+
+    with mock.patch(
+        "lightning_lite.plugins.collectives.single_device_collective.SingleDeviceCollective.new_group"
+    ) as new_mock:
+        collective.create_group(arg1=15, arg3=10)
+
+    group_kwargs = {"arg3": 10, "arg1": 15}
+    new_mock.assert_called_once_with(**group_kwargs)
+
+    with mock.patch("lightning_lite.plugins.collectives.single_device_collective.SingleDeviceCollective.destroy_group"):
+        collective.teardown()
diff --git a/tests/tests_lite/plugins/collectives/test_torch_collective.py b/tests/tests_lite/plugins/collectives/test_torch_collective.py
new file mode 100644
index 0000000000000..28fde2c18a75e
--- /dev/null
+++ b/tests/tests_lite/plugins/collectives/test_torch_collective.py
@@ -0,0 +1,245 @@
+import datetime
+import os
+from functools import partial
+from unittest import mock
+
+import pytest
+import torch
+from tests_lite.helpers.runif import RunIf
+
+from lightning_lite.accelerators import CPUAccelerator
+from lightning_lite.plugins.collectives import TorchCollective
+from lightning_lite.plugins.environments import LightningEnvironment
+from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
+from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+
+if TorchCollective.is_available():
+    from torch.distributed import ReduceOp
+else:
+    ReduceOp = mock.Mock()
+
+skip_distributed_unavailable = pytest.mark.skipif(
+    not TorchCollective.is_available(), reason="torch.distributed unavailable"
+)
+
+PASSED_TENSOR = mock.Mock()
+PASSED_OBJECT = mock.Mock()
+
+
+@pytest.fixture(autouse=True)
+def check_destroy_group():
+    with mock.patch(
+        "lightning_lite.plugins.collectives.torch_collective.TorchCollective.new_group",
+        wraps=TorchCollective.new_group,
+    ) as mock_new, mock.patch(
+        "lightning_lite.plugins.collectives.torch_collective.TorchCollective.destroy_group",
+        wraps=TorchCollective.destroy_group,
+    ) as mock_destroy:
+        yield
+    assert (
+        mock_new.call_count == mock_destroy.call_count
+    ), "new_group and destroy_group should be called the same number of times"
+    if TorchCollective.is_available():
+        assert not TorchCollective.is_initialized()
+
+
+@pytest.mark.parametrize(
+    ["fn_name", "kwargs", "return_key"],
+    [
+        ("send", {"tensor": PASSED_TENSOR, "dst": 0, "tag": 0}, None),
+        ("recv", {"tensor": PASSED_TENSOR, "src": 0, "tag": 0}, "tensor"),
+        ("broadcast", {"tensor": PASSED_TENSOR, "src": 0}, "tensor"),
+        ("all_reduce", {"tensor": PASSED_TENSOR, "op": ReduceOp.SUM}, "tensor"),
+        ("reduce", {"tensor": PASSED_TENSOR, "dst": 0, "op": ReduceOp.SUM}, "tensor"),
+        ("all_gather", {"tensor_list": [PASSED_TENSOR], "tensor": PASSED_TENSOR}, "tensor_list"),
+        ("gather", {"tensor": PASSED_TENSOR, "gather_list": [PASSED_TENSOR], "dst": 0}, "gather_list"),
+        ("scatter", {"tensor": PASSED_TENSOR, "scatter_list": [PASSED_TENSOR], "src": 0}, "tensor"),
+        ("reduce_scatter", {"output": PASSED_TENSOR, "input_list": [PASSED_TENSOR], "op": ReduceOp.SUM}, "output"),
+        (
+            "all_to_all",
+            {"output_tensor_list": [PASSED_TENSOR], "input_tensor_list": [PASSED_TENSOR]},
+            "output_tensor_list",
+        ),
+        ("barrier", {"device_ids": [0]}, None),
+        ("all_gather_object", {"object_list": [PASSED_OBJECT], "obj": PASSED_OBJECT}, "object_list"),
+        pytest.param(
+            "broadcast_object_list",
+            {"object_list": [PASSED_OBJECT], "src": 0},
+            "object_list",
+            marks=RunIf(max_torch="1.10"),
+        ),
+        pytest.param(
+            "broadcast_object_list",
+            {"object_list": [PASSED_OBJECT], "src": 0, "device": torch.device("cpu")},
+            "object_list",
+            marks=RunIf(min_torch="1.10"),
+        ),
+        (
+            "gather_object",
+            {"obj": PASSED_OBJECT, "object_gather_list": [PASSED_OBJECT], "dst": 0},
+            "object_gather_list",
+        ),
+        (
+            "scatter_object_list",
+            {"scatter_object_output_list": [PASSED_OBJECT], "scatter_object_input_list": [PASSED_OBJECT], "src": 0},
+            "scatter_object_output_list",
+        ),
+        ("monitored_barrier", {"timeout": datetime.timedelta(seconds=1), "wait_all_ranks": False}, None),
+    ],
+)
+@skip_distributed_unavailable
+def test_collective_calls_with_created_group(fn_name, kwargs, return_key):
+    collective = TorchCollective()
+    with mock.patch("torch.distributed.init_process_group"):
+        collective.setup()
+    with mock.patch("torch.distributed.new_group"):
+        collective.create_group()
+    fn = getattr(collective, fn_name)
+    with mock.patch(f"torch.distributed.{fn_name}", autospec=True) as mock_call:
+        result = fn(**kwargs)
+    mock_call.assert_called_once_with(**kwargs, group=collective.group)
+    if return_key is not None:
+        assert result == kwargs[return_key]
+
+    with mock.patch("torch.distributed.destroy_process_group"):
+        collective.teardown()
+
+
+@skip_distributed_unavailable
+def test_convert_ops():
+    # Test regular names
+    assert TorchCollective._convert_to_native_op("band") == ReduceOp.BAND
+    assert TorchCollective._convert_to_native_op("bor") == ReduceOp.BOR
+    assert TorchCollective._convert_to_native_op("bxor") == ReduceOp.BXOR
+    assert TorchCollective._convert_to_native_op("max") == ReduceOp.MAX
+    assert TorchCollective._convert_to_native_op("min") == ReduceOp.MIN
+    assert TorchCollective._convert_to_native_op("product") == ReduceOp.PRODUCT
+    assert TorchCollective._convert_to_native_op("sum") == ReduceOp.SUM
+    # Test we are passing through native ops without change
+    assert TorchCollective._convert_to_native_op(ReduceOp.BAND) == ReduceOp.BAND
+    assert TorchCollective._convert_to_native_op(ReduceOp.BOR) == ReduceOp.BOR
+    assert TorchCollective._convert_to_native_op(ReduceOp.BXOR) == ReduceOp.BXOR
+    assert TorchCollective._convert_to_native_op(ReduceOp.MAX) == ReduceOp.MAX
+    assert TorchCollective._convert_to_native_op(ReduceOp.MIN) == ReduceOp.MIN
+    assert TorchCollective._convert_to_native_op(ReduceOp.PRODUCT) == ReduceOp.PRODUCT
+    assert TorchCollective._convert_to_native_op(ReduceOp.SUM) == ReduceOp.SUM
+    # Test we are handling different casing properly
+    assert TorchCollective._convert_to_native_op("BOR") == ReduceOp.BOR
+    assert TorchCollective._convert_to_native_op("BoR") == ReduceOp.BOR
+
+    # AVG is very recent!
+    if _TORCH_GREATER_EQUAL_1_11:
+        assert TorchCollective._convert_to_native_op("avg") == ReduceOp.AVG
+
+    # Test invalid type
+    with pytest.raises(ValueError, match="op 1 should be a `str` or `ReduceOp`"):
+        TorchCollective._convert_to_native_op(1)
+
+    # Test invalid string
+    with pytest.raises(ValueError, match="op 'INVALID' is not a member of `ReduceOp`"):
+        TorchCollective._convert_to_native_op("invalid")
+
+
+@skip_distributed_unavailable
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_repeated_create_and_destroy():
+    collective = TorchCollective()
+    with mock.patch("torch.distributed.init_process_group"):
+        collective.setup(main_address="foo", main_port=123)
+
+    assert not os.environ
+
+    with mock.patch("torch.distributed.new_group") as new_mock:
+        collective.create_group()
+    new_mock.assert_called_once()
+
+    with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
+        collective.create_group()
+
+    with mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+        collective.teardown()
+    destroy_mock.assert_called_once()
+
+    assert not os.environ
+
+    with pytest.raises(RuntimeError, match="TorchCollective` does not own a group to destroy"):
+        collective.teardown()
+    destroy_mock.assert_called_once_with(new_mock.return_value)
+    assert collective._group is None
+
+    with mock.patch("torch.distributed.new_group"), mock.patch("torch.distributed.destroy_process_group"):
+        # check we can create_group again. also chaining
+        collective.create_group().teardown()
+
+
+def collective_launch(fn, parallel_devices, num_groups=1):
+    strategy = DDPSpawnStrategy(
+        accelerator=CPUAccelerator(), parallel_devices=parallel_devices, cluster_environment=LightningEnvironment()
+    )
+    launcher = _MultiProcessingLauncher(strategy=strategy)
+    collectives = [TorchCollective() for _ in range(num_groups)]
+    wrapped = partial(wrap_launch_function, fn, strategy, collectives[0])
+    return launcher.launch(wrapped, strategy, *collectives)
+
+
+def wrap_launch_function(fn, strategy, collective, *args, **kwargs):
+    strategy._set_world_ranks()
+    collective.setup(
+        world_size=strategy.num_processes,
+        main_address="localhost",
+        backend="gloo",
+        rank=strategy.global_rank,
+    )
+    return fn(*args, **kwargs)
+
+
+def _test_distributed_collectives_fn(strategy, collective):
+    collective.create_group()
+
+    # all_gather
+    tensor_list = [torch.zeros(2, dtype=torch.long) for _ in range(strategy.num_processes)]
+    this = torch.arange(2, dtype=torch.long) + 2 * strategy.global_rank
+    out = collective.all_gather(tensor_list, this)
+    expected = torch.arange(2 * strategy.num_processes).split(2)
+    torch.testing.assert_close(tuple(out), expected)
+
+    # reduce
+    this = torch.tensor(strategy.global_rank + 1)
+    out = collective.reduce(this, dst=0, op="max")
+    expected = torch.tensor(strategy.num_processes) if strategy.global_rank == 0 else this
+    torch.testing.assert_close(out, expected)
+
+    # all_reduce
+    this = torch.tensor(strategy.global_rank + 1)
+    out = collective.all_reduce(this, op="min")
+    expected = torch.tensor(1)
+    torch.testing.assert_close(out, expected)
+
+    collective.teardown()
+
+
+@skip_distributed_unavailable
+@pytest.mark.parametrize("n", (1, 2))
+def test_collectives_distributed(n):
+    collective_launch(_test_distributed_collectives_fn, [torch.device("cpu")] * n)
+
+
+def _test_two_groups(strategy, left_collective, right_collective):
+    left_collective.create_group(ranks=[0, 1])
+    right_collective.create_group(ranks=[1, 2])
+
+    if strategy.global_rank in (0, 1):
+        tensor = torch.tensor(strategy.global_rank)
+        left_collective.all_reduce(tensor)
+        assert tensor == 1
+    right_collective.barrier()
+    if right_collective.rank >= 0:
+        tensor = torch.tensor(strategy.global_rank)
+        right_collective.all_reduce(tensor)
+        assert tensor == 3
+
+
+@skip_distributed_unavailable
+def test_two_groups():
+    collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)
diff --git a/tests/tests_pytorch/utilities/test_all_gather_grad.py b/tests/tests_pytorch/utilities/test_all_gather_grad.py
index 2a74ef85f7e74..169825e8c699a 100644
--- a/tests/tests_pytorch/utilities/test_all_gather_grad.py
+++ b/tests/tests_pytorch/utilities/test_all_gather_grad.py
@@ -15,7 +15,7 @@
 import numpy as np
 import torch
 
-from lightning_lite.utilities import AllGatherGrad
+from lightning_lite.utilities.distributed import AllGatherGrad
 from lightning_lite.utilities.seed import seed_everything
 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
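
Usage sketch for the collectives API introduced above. This is illustrative only: the `gloo` backend, the `localhost`/port `12345` rendezvous, and the two-process `torch.multiprocessing.spawn` launch are assumptions made for the demo, not requirements of the API.

    import torch
    import torch.multiprocessing as mp

    from lightning_lite.plugins.collectives import TorchCollective


    def run(rank: int, world_size: int) -> None:
        collective = TorchCollective()
        # `setup` initializes the default process group if needed; MASTER_ADDR/MASTER_PORT are set
        # temporarily from `main_address`/`main_port` and removed again afterwards.
        collective.setup(
            main_address="localhost",  # assumption for the demo
            main_port="12345",  # assumption for the demo
            backend="gloo",
            rank=rank,
            world_size=world_size,
        )
        # `create_group` makes this collective own a (sub)group; with no arguments it spans all ranks.
        collective.create_group()

        tensor = torch.tensor(float(rank + 1))
        collective.all_reduce(tensor, op="sum")  # string ops are converted to `ReduceOp` members
        assert tensor.item() == sum(range(1, world_size + 1))

        collective.barrier()
        collective.teardown()  # destroys the owned group; the default group is cleaned up at process exit


    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)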