diff --git a/pytorch_lightning/plugins/collective/__init__.py b/pytorch_lightning/plugins/collective/__init__.py
new file mode 100644
index 0000000000000..9b4b7c3260d7f
--- /dev/null
+++ b/pytorch_lightning/plugins/collective/__init__.py
@@ -0,0 +1,18 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.plugins.collective.collective_plugin import CollectivePlugin  # noqa: F401
+from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective  # noqa: F401
+from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective  # noqa: F401
+from pytorch_lightning.plugins.collective.torch_collective import TorchCollective  # noqa: F401
+from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective  # noqa: F401
diff --git a/pytorch_lightning/plugins/collective/collective_plugin.py b/pytorch_lightning/plugins/collective/collective_plugin.py
new file mode 100644
index 0000000000000..2a2d6f0020144
--- /dev/null
+++ b/pytorch_lightning/plugins/collective/collective_plugin.py
@@ -0,0 +1,79 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional, Union
+
+import torch
+
+from pytorch_lightning.utilities.distributed import ReduceOp
+
+
+class CollectivePlugin(ABC):
+    """Interface for collective functions.
+
+    Lightning collectives support communication between multiple processes and multiple nodes, providing routines
+    such as barrier, broadcast, all_gather, and reduce.
+
+    .. note::
+        This API is experimental/in-beta and subject to change.
+    """
+
+    @abstractmethod
+    def barrier(self, name: Optional[str] = None) -> None:
+        """Synchronizes all processes, blocking them until the whole group enters this function.
+
+        Args:
+            name: an optional name for the barrier. Only the XLA (TPU) implementation uses it.
+        """
+
+    @abstractmethod
+    def broadcast(self, obj: object, src: int = 0) -> object:
+        """Broadcasts an object to all processes.
+
+        Args:
+            obj: the object to broadcast
+            src: source rank.
+        """
+
+    @abstractmethod
+    def all_gather(
+        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
+        """Perform an all_gather on all processes.
+
+        Args:
+            tensor: the tensor to all_gather
+            process_group: the process group to gather results from
+            sync_grads: flag that allows users to synchronize gradients for the all_gather op
+
+        Returns: a tensor (torch.distributed) or a list of tensors (Horovod)
+        """
+
+    @abstractmethod
+    def reduce(
+        self,
+        tensor: Union[torch.Tensor, Any],
+        process_group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[torch.Tensor, Any]:
+        """Reduces the given tensor (e.g. across GPUs/processes).
+
+        Args:
+            tensor: the tensor to sync and reduce
+            process_group: the process group to reduce
+            reduce_op: the reduction operation. Defaults to 'mean'.
+                Can also be a string 'sum' or ReduceOp.
+
+        Returns: the reduced tensor.
+        """
diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py
new file mode 100644
index 0000000000000..0ed5924809ca2
--- /dev/null
+++ b/pytorch_lightning/plugins/collective/horovod_collective.py
@@ -0,0 +1,93 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, List, Optional, Union
+
+import torch
+
+from pytorch_lightning.plugins.collective import CollectivePlugin
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as dist_group
+from pytorch_lightning.utilities.distributed import ReduceOp
+
+if _HOROVOD_AVAILABLE:
+    import horovod.torch as hvd
+
+
+class HorovodCollective(CollectivePlugin):
+    """Collective interface for Horovod training type plugins."""
+
+    def __init__(
+        self,
+        on_gpu: bool = False,
+        local_rank: int = 0,
+    ) -> None:
+        self.on_gpu = on_gpu
+        self.local_rank = local_rank
+
+    def _join(self) -> None:
+        """Horovod function that indicates that the rank finished processing data.
+
+        All ranks that did not call join() continue to process allreduce operations. This function blocks the Python
+        thread until all ranks join.
+        """
+        if self.on_gpu:
+            hvd.join(self.local_rank)
+        else:
+            hvd.join()
+
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
+        if distributed_available():
+            self._join()
+
+    def broadcast(self, obj: object, src: int = 0) -> object:
+        return hvd.broadcast_object(obj, src)
+
+    def all_gather(
+        self, result: torch.Tensor, process_group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
+    ) -> List[torch.Tensor]:
+        if process_group is not None and process_group != dist_group.WORLD:
+            raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")
+
+        if len(result.shape) == 0:
+            # Convert scalars to single dimension tensors
+            result = result.reshape(1)
+
+        # sync and gather all
+        self._join()
+        gathered = hvd.allgather(result)
+        gathered_result = list(gathered.split(1, dim=0))
+        return gathered_result
+
+    def reduce(
+        self,
+        tensor: Union[torch.Tensor, Any],
+        process_group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[torch.Tensor, Any]:
+        if process_group is not None:
+            raise ValueError(
+                "Horovod does not support allreduce using a subcommunicator at this time. Unset `process_group`."
+            )
+
+        if reduce_op in (None, "avg", "mean"):
+            reduce_op = hvd.Average
+        elif reduce_op in ("sum", ReduceOp.SUM):
+            reduce_op = hvd.Sum
+        else:
+            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")
+
+        # sync all processes before reduction
+        self._join()
+        return hvd.allreduce(tensor, op=reduce_op)
diff --git a/pytorch_lightning/plugins/collective/single_device_collective.py b/pytorch_lightning/plugins/collective/single_device_collective.py
new file mode 100644
index 0000000000000..4c8543663b4c0
--- /dev/null
+++ b/pytorch_lightning/plugins/collective/single_device_collective.py
@@ -0,0 +1,36 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Optional, Union
+
+import torch
+
+from pytorch_lightning.plugins.collective import CollectivePlugin
+
+
+class SingleDeviceCollective(CollectivePlugin):
+    """Collective interface for single device training type plugins."""
+
+    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
+        pass
+
+    def broadcast(self, obj: object, src: int = 0) -> object:
+        return obj
+
+    def all_gather(
+        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
+    ) -> torch.Tensor:
+        return tensor
+
+    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
+        return tensor
diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py
new file mode 100644
index 0000000000000..83b17fe02ecd9
--- /dev/null
+++ b/pytorch_lightning/plugins/collective/torch_collective.py
@@ -0,0 +1,108 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Optional, Union
+
+import torch
+import torch.distributed
+
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
+from pytorch_lightning.plugins.collective import CollectivePlugin
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
+from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, distributed_available
+from pytorch_lightning.utilities.distributed import group as dist_group
+from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.types import _METRIC_COLLECTION
+
+
+class TorchCollective(CollectivePlugin):
+    """Collective interface for PyTorch.
+
+    Mainly used by DDP, DDPSpawn, DP and DDP2.
+    """
+
+    def __init__(
+        self,
+        local_reduce: bool = False,
+        rank: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = torch.device("cpu"),
+        device_id: Optional[int] = None,
+    ) -> None:
+        """
+        Note:
+            DDP and DDPSpawn sync across multiple nodes/devices, local_reduce = False
+            DP runs reduce on a single node, local_reduce = True
+            DDP2 behaves like DP within one node, local_reduce = True
+
+            local_reduce is set in the plugin's setup() function.
+        """
+        self.local_reduce = local_reduce
+        self.rank = rank
+        self.device = device
+        self.device_id = device_id
+
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
+        if not distributed_available():
+            return
+        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
+            torch.distributed.barrier(device_ids=[self.device_id] if self.device_id is not None else None)
+        else:
+            torch.distributed.barrier()
+
+    def broadcast(self, obj: Any, src: int = 0) -> Any:
+        if not distributed_available():
+            return obj
+        obj = [obj]
+        if self.rank != 0:
+            obj = [None] * len(obj)
+        broadcast_object_list(obj, src, group=dist_group.WORLD)
+        return obj[0]
+
+    def all_gather(
+        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
+    ) -> torch.Tensor:
+        return all_gather_ddp_if_available(tensor, group=process_group, sync_grads=sync_grads)
+
+    def reduce(
+        self,
+        tensor: Union[torch.Tensor, _METRIC_COLLECTION],
+        process_group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[torch.Tensor, _METRIC_COLLECTION]:
+        """Reduces the given tensor (e.g. across GPUs/processes).
+
+        If local_reduce = True (DP and DDP2), reduces the tensor across all local processes.
+
+        If local_reduce = False (DDP, DDPSpawn and extensions), reduces a tensor from several distributed processes.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            process_group: the process group to reduce
+            reduce_op: the reduction operation. Defaults to 'mean'.
+                Can also be a string 'sum' or ReduceOp.
+
+        Returns:
+            The reduced value; if the input was not a tensor, it is returned unchanged.
+        """
+        if self.local_reduce:
+
+            def mean(t: torch.Tensor) -> torch.Tensor:
+                original_dtype = t.dtype
+                return t.float().mean().to(original_dtype)
+
+            return apply_to_collection(tensor, torch.Tensor, mean)
+
+        if isinstance(tensor, torch.Tensor):
+            tensor = sync_ddp_if_available(tensor, process_group, reduce_op=reduce_op)
+        return tensor
diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py
new file mode 100644
index 0000000000000..4b5238758b035
--- /dev/null
+++ b/pytorch_lightning/plugins/collective/tpu_collective.py
@@ -0,0 +1,99 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import os
+from typing import Any, Optional, Union
+
+import torch
+
+from pytorch_lightning.plugins.collective import CollectivePlugin
+from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities.distributed import ReduceOp
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+if _TPU_AVAILABLE:
+    import torch_xla.core.xla_env_vars as xenv
+    import torch_xla.core.xla_model as xm
+    from torch_xla.core.xla_model import rendezvous
+
+
+class TPUCollective(CollectivePlugin):
+    """Collective interface for TPUSpawn training type plugins."""
+
+    def __init__(
+        self,
+        device: Union[str, torch.device] = torch.device("xla"),
+        root_device: torch.device = torch.device("xla"),
+        world_size: int = 1,
+    ):
+        self.device = device
+        self.root_device = root_device
+        self.world_size = world_size
+
+    @property
+    def _is_distributed(self) -> bool:
+        # HOST_WORLD_SIZE is None outside the xmp.spawn process
+        return os.getenv(xenv.HOST_WORLD_SIZE, None) is not None and self.world_size != 1
+
+    def barrier(self, name: Optional[str] = None) -> None:
+        if self._is_distributed:
+            rendezvous(name)
+
+    def broadcast(self, obj: object, src: int = 0) -> object:
+        if not self._is_distributed:
+            return obj
+        buffer = io.BytesIO()
+        torch.save(obj, buffer)
+        data = bytearray(buffer.getbuffer())
+        data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
+        data = xm.all_gather(data_tensor)
+        buffer = io.BytesIO(data.cpu().byte().numpy())
+        obj = torch.load(buffer)
+        return obj
+
+    def all_gather(
+        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
+    ) -> torch.Tensor:
+        """
+        Function to gather a tensor from several distributed processes.
+        Args:
+            tensor: tensor of shape (batch, ...)
+            process_group: not available with TPUs
+            sync_grads: not available with TPUs
+        Returns:
+            A tensor of shape (world_size, batch, ...)
+        """
+        if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+        return xm.all_gather(tensor)
+
+    def reduce(
+        self, output: Any, process_group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    ) -> Any:
+        if not isinstance(output, torch.Tensor):
+            output = torch.tensor(output, device=self.device)
+
+        _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
+        _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
+        if _invalid_reduce_op or _invalid_reduce_op_str:
+            raise MisconfigurationException(
+                "Currently, the TPUSpawn TrainingTypePlugin only supports `sum`, `mean`, `avg` reduce operations."
+            )
+
+        output = xm.mesh_reduce("reduce", output, sum)
+
+        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
+            output = output / self.world_size
+
+        return output
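
Usage note (not part of the diff): a minimal sketch of how the new CollectivePlugin interface can be exercised through its single-device implementation. All names come from the files added above; the snippet assumes the import path lands exactly as shown in this diff.

    import torch
    from pytorch_lightning.plugins.collective import SingleDeviceCollective

    # The single-device implementation is a no-op backend: every collective call
    # returns its input unchanged, which makes it handy for smoke-testing call sites.
    collective = SingleDeviceCollective()
    collective.barrier()                                    # no-op on a single device
    obj = collective.broadcast({"epoch": 3})                # returns the object unchanged
    gathered = collective.all_gather(torch.ones(2, 2))      # returns the input tensor
    reduced = collective.reduce(torch.tensor([1.0, 2.0]))   # returns the input tensor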