From 3b67185fc1aa5149a8d5b61061f95460288262f4 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Thu, 9 Sep 2021 11:50:12 -0700 Subject: [PATCH 01/11] 2/n Consolidate collective functions - collective base and subclasses --- CHANGELOG.md | 7 ++ .../plugins/collective/__init__.py | 18 +++ .../plugins/collective/collective_plugin.py | 47 ++++++++ .../plugins/collective/horovod_collective.py | 110 ++++++++++++++++++ .../collective/single_device_collective.py | 44 +++++++ .../plugins/collective/torch_collective.py | 95 +++++++++++++++ .../plugins/collective/tpu_collective.py | 80 +++++++++++++ 7 files changed, 401 insertions(+) create mode 100644 pytorch_lightning/plugins/collective/__init__.py create mode 100644 pytorch_lightning/plugins/collective/collective_plugin.py create mode 100644 pytorch_lightning/plugins/collective/horovod_collective.py create mode 100644 pytorch_lightning/plugins/collective/single_device_collective.py create mode 100644 pytorch_lightning/plugins/collective/torch_collective.py create mode 100644 pytorch_lightning/plugins/collective/tpu_collective.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a168db7b486d..0a8e6f1715428 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -120,7 +120,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `TrainerModelHooksMixin.is_function_implemented` and `TrainerModelHooksMixin.has_arg` ([#10322](https://github.com/PyTorchLightning/pytorch-lightning/pull/10322)) +<<<<<<< HEAD - Removed deprecated `pytorch_lightning.utilities.device_dtype_mixin.DeviceDtypeModuleMixin` in favor of `pytorch_lightning.core.mixins.device_dtype_mixin.DeviceDtypeModuleMixin` ([#10442](https://github.com/PyTorchLightning/pytorch-lightning/pull/10442)) +======= +- Add collective base class and subclasses ([#9414](https://github.com/PyTorchLightning/pytorch-lightning/pull/9414)) + + +### Changed +>>>>>>> f12625596 (2/n Consolidate collective functions - collective base and subclasses) - Removed deprecated `LightningModule.loaded_optimizer_states_dict` property ([#10346](https://github.com/PyTorchLightning/pytorch-lightning/pull/10346)) diff --git a/pytorch_lightning/plugins/collective/__init__.py b/pytorch_lightning/plugins/collective/__init__.py new file mode 100644 index 0000000000000..93bf987e93652 --- /dev/null +++ b/pytorch_lightning/plugins/collective/__init__.py @@ -0,0 +1,18 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from pytorch_lightning.plugins.collective.collective_plugin import Collective # noqa: F401 +from pytorch_lightning.plugins.collective.single_device_collective import SingleNodeCollective # noqa: F401 +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective # noqa: F401 +from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective # noqa: F401 +from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective # noqa: F401 diff --git a/pytorch_lightning/plugins/collective/collective_plugin.py b/pytorch_lightning/plugins/collective/collective_plugin.py new file mode 100644 index 0000000000000..0a63f48f74da0 --- /dev/null +++ b/pytorch_lightning/plugins/collective/collective_plugin.py @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Optional, Union + +import torch + + +class Collective(ABC): + """Base class for collective functions for training type plugins.""" + + @abstractmethod + def barrier(self, name: Optional[str] = None, *args, **kwargs) -> None: + """Forces all possibly joined processes to wait for each other.""" + + @abstractmethod + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes.""" + + @abstractmethod + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes.""" + + @abstractmethod + def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + """Reduces the given tensor (e.g. across GPUs/processes). + + Args: + tensor: the tensor to sync and reduce + *args: plugin-specific positional arguments + **kwargs: plugin-specific keyword arguments + """ + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py new file mode 100644 index 0000000000000..bc9c4bdcd5bf7 --- /dev/null +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -0,0 +1,110 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import io +from typing import Any, Optional, Union + +import torch + +from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE +from pytorch_lightning.utilities.distributed import ReduceOp +from pytorch_lightning.utilities.types import _TPU_AVAILABLE + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + from torch_xla.core.xla_model import rendezvous +else: + xm, rendezvous = [None] * 4 + +if _HOROVOD_AVAILABLE: + import horovod.torch as hvd + + +class HorovodCollective(Collective): + """Collective interface for Horovod training type plugins.""" + + def __init__( + self, + on_gpu: Optional[bool] = False, + local_rank: Optional[int] = 0, + ): + self._on_gpu = on_gpu + self._local_rank = local_rank + + def join(self): + """Horovod function that indicates that the rank finished processing data. + + All ranks that did not call join() continue to process allreduce operations. This function blocks Python thread + until all ranks join. + """ + if self.on_gpu: + hvd.join(self.local_rank) + else: + hvd.join() + + def barrier(self, name: Optional[str] = None) -> None: + if self.is_distributed: + rendezvous(name) + + def broadcast(self, obj: object, src: int = 0) -> object: + if not self.is_distributed: + return obj + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) + data = xm.all_gather(data_tensor) + buffer = io.BytesIO(data.cpu().byte().numpy()) + obj = torch.load(buffer) + return obj + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + Args: + tensor: tensor of shape (batch, ...) + group: not available with TPUs + sync_grads: not available with TPUs + Return: + A tensor of shape (world_size, batch, ...) + """ + if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + return self._xm.all_gather(tensor) + + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + """Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if group is not None: + raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. Unset `group`.") + + if reduce_op in (None, "avg", "mean"): + reduce_op = hvd.Average + elif reduce_op in ("sum", ReduceOp.SUM): + reduce_op = hvd.Sum + else: + raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") + + # sync all processes before reduction + self.join() + return hvd.allreduce(tensor, op=reduce_op) diff --git a/pytorch_lightning/plugins/collective/single_device_collective.py b/pytorch_lightning/plugins/collective/single_device_collective.py new file mode 100644 index 0000000000000..84a4c44149869 --- /dev/null +++ b/pytorch_lightning/plugins/collective/single_device_collective.py @@ -0,0 +1,44 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union + +import torch + +from pytorch_lightning.plugins.collective import Collective + + +class SingleNodeCollective(Collective): + """Collective interface for single device training type plugins.""" + + def barrier(self, name: Optional[str] = None, *args, **kwargs) -> None: + """Forces all possibly joined processes to wait for each other.""" + pass + + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes.""" + return obj + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes.""" + return tensor + + def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + """Reduces the given tensor (e.g. across GPUs/processes). + + Args: + tensor: the tensor to sync and reduce + *args: plugin-specific positional arguments + **kwargs: plugin-specific keyword arguments + """ + return tensor diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py new file mode 100644 index 0000000000000..2acaad10c14c1 --- /dev/null +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -0,0 +1,95 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union + +import torch +import torch.distributed + +from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.distributed import ( + all_gather_ddp_if_available, + distributed_available, + ReduceOp, + sync_ddp_if_available, +) +from pytorch_lightning.utilities.types import _METRIC_COLLECTION + + +class TorchCollective(Collective): + """Collective interface for DDP, DDPSpawn, DP and DDP2.""" + + def __init__(self, local_reduce=False): + """.. 
note:: + + DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False + DP run reduce in on node, local_reduce = True + DDP2 behaves like DP in one node, local_reduce = True + + local_reduce set in Plugins.setup() functions + """ + self.local_reduce = local_reduce + + def barrier(self, *args, **kwargs) -> None: + if not distributed_available(): + return + if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=self.determine_ddp_device_ids()) + else: + torch.distributed.barrier() + + def broadcast(self, obj: object, src: int = 0) -> object: + if not distributed_available(): + return obj + return self.dist.broadcast(obj) + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes.""" + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + + def reduce( + self, tensor: _METRIC_COLLECTION, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean" + ) -> torch.Tensor: + """Reduces the given tensor (e.g. across GPUs/processes) + + If local_reduce = True (dp and ddp2), reduces tensor from all local processes. + + If local_reduce = False (ddp, ddpspawning and extentions), reduces a tensor from several distributed processes + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if self.local_reduce: + + def mean(t: torch.Tensor) -> torch.Tensor: + original_dtype = t.dtype + return t.float().mean().to(original_dtype) + + return apply_to_collection(tensor, torch.Tensor, mean) + + if isinstance(tensor, torch.Tensor): + tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + def reduce_boolean_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = self.reduce(decision, reduce_op=ReduceOp.SUM) + decision = bool(decision == self.world_size) + return decision diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py new file mode 100644 index 0000000000000..c31df54e2d1d7 --- /dev/null +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -0,0 +1,80 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import io +from typing import Any, Optional, Union + +import torch + +from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities.distributed import ReduceOp +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _TPU_AVAILABLE + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + from torch_xla.core.xla_model import rendezvous +else: + xm, rendezvous = [None] * 4 + + +class TPUCollective(Collective): + """Collective interface for TPU and TPUSpawning training type plugins.""" + + def barrier(self, name: Optional[str] = None) -> None: + if self.is_distributed: + rendezvous(name) + + def broadcast(self, obj: object, src: int = 0) -> object: + if not self.is_distributed: + return obj + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) + data = xm.all_gather(data_tensor) + buffer = io.BytesIO(data.cpu().byte().numpy()) + obj = torch.load(buffer) + return obj + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + Args: + tensor: tensor of shape (batch, ...) + group: not available with TPUs + sync_grads: not available with TPUs + Return: + A tensor of shape (world_size, batch, ...) + """ + if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + return self._xm.all_gather(tensor) + + def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): + if not isinstance(output, torch.Tensor): + output = torch.tensor(output, device=self.lightning_module.device) + + _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM + _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") + if _invalid_reduce_op or _invalid_reduce_op_str: + raise MisconfigurationException( + "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation." + ) + + output = xm.mesh_reduce("reduce", output, sum) + + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + output = output / self.world_size + + return output From 4ddaf0f7e4c9ef53d5653c78215b5799c1d1467e Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Thu, 9 Sep 2021 18:06:38 -0700 Subject: [PATCH 02/11] 2/n Consolidate collective functions - collective base and subclasses --- .../plugins/collective/collective_plugin.py | 19 +++-- .../plugins/collective/horovod_collective.py | 73 +++++++++---------- .../collective/single_device_collective.py | 6 +- .../plugins/collective/torch_collective.py | 32 +++++--- .../plugins/collective/tpu_collective.py | 8 +- 5 files changed, 80 insertions(+), 58 deletions(-) diff --git a/pytorch_lightning/plugins/collective/collective_plugin.py b/pytorch_lightning/plugins/collective/collective_plugin.py index 0a63f48f74da0..c3512452883d3 100644 --- a/pytorch_lightning/plugins/collective/collective_plugin.py +++ b/pytorch_lightning/plugins/collective/collective_plugin.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch +from pytorch_lightning.utilities.distributed import ReduceOp + class Collective(ABC): """Base class for collective functions for training type plugins.""" @abstractmethod - def barrier(self, name: Optional[str] = None, *args, **kwargs) -> None: + def barrier(self, name: Optional[str] = None) -> None: """Forces all possibly joined processes to wait for each other.""" @abstractmethod @@ -29,11 +31,18 @@ def broadcast(self, obj: object, src: int = 0) -> object: """Broadcasts an object to all processes.""" @abstractmethod - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + def all_gather( + self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[List[torch.Tensor], torch.Tensor]: """Perform a all_gather on all processes.""" @abstractmethod - def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + def reduce( + self, + tensor: Union[torch.Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, Any]: """Reduces the given tensor (e.g. across GPUs/processes). Args: @@ -42,6 +51,6 @@ def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> **kwargs: plugin-specific keyword arguments """ + @abstractmethod def reduce_boolean_decision(self, decision: bool) -> bool: """Reduce the early stopping decision across all processes.""" - return decision diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py index bc9c4bdcd5bf7..19f6254964646 100644 --- a/pytorch_lightning/plugins/collective/horovod_collective.py +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -11,21 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from pytorch_lightning.plugins.collective import Collective from pytorch_lightning.utilities import _HOROVOD_AVAILABLE +from pytorch_lightning.utilities.distributed import distributed_available +from pytorch_lightning.utilities.distributed import group as dist_group from pytorch_lightning.utilities.distributed import ReduceOp -from pytorch_lightning.utilities.types import _TPU_AVAILABLE - -if _TPU_AVAILABLE: - import torch_xla.core.xla_model as xm - from torch_xla.core.xla_model import rendezvous -else: - xm, rendezvous = [None] * 4 if _HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -42,48 +36,47 @@ def __init__( self._on_gpu = on_gpu self._local_rank = local_rank - def join(self): + def join(self) -> None: """Horovod function that indicates that the rank finished processing data. All ranks that did not call join() continue to process allreduce operations. This function blocks Python thread until all ranks join. 
""" - if self.on_gpu: - hvd.join(self.local_rank) + if self._on_gpu: + hvd.join(self._local_rank) else: hvd.join() - def barrier(self, name: Optional[str] = None) -> None: - if self.is_distributed: - rendezvous(name) + def barrier(self, *args: Any, **kwargs: Any) -> None: + if distributed_available(): + self.join() def broadcast(self, obj: object, src: int = 0) -> object: - if not self.is_distributed: - return obj - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) - data = xm.all_gather(data_tensor) - buffer = io.BytesIO(data.cpu().byte().numpy()) - obj = torch.load(buffer) + obj = hvd.broadcast_object(obj, src) return obj - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes - Args: - tensor: tensor of shape (batch, ...) - group: not available with TPUs - sync_grads: not available with TPUs - Return: - A tensor of shape (world_size, batch, ...) - """ - if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: - tensor = tensor.unsqueeze(0) - return self._xm.all_gather(tensor) + def all_gather( + self, result: Union[torch.Tensor], group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False + ) -> List[torch.Tensor]: + if group is not None and group != dist_group.WORLD: + raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.") - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + if len(result.shape) == 0: + # Convert scalars to single dimension tensors + result = result.reshape(1) + + # sync and gather all + self.join() + gathered = hvd.allgather(result) + gathered_result = list(gathered.split(1, dim=0)) + return gathered_result + + def reduce( + self, + tensor: Union[torch.Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, Any]: """Reduces a tensor from several distributed processes to one aggregated tensor. 
Args: @@ -108,3 +101,7 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ # sync all processes before reduction self.join() return hvd.allreduce(tensor, op=reduce_op) + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision diff --git a/pytorch_lightning/plugins/collective/single_device_collective.py b/pytorch_lightning/plugins/collective/single_device_collective.py index 84a4c44149869..7bd34abd6948d 100644 --- a/pytorch_lightning/plugins/collective/single_device_collective.py +++ b/pytorch_lightning/plugins/collective/single_device_collective.py @@ -21,7 +21,7 @@ class SingleNodeCollective(Collective): """Collective interface for single device training type plugins.""" - def barrier(self, name: Optional[str] = None, *args, **kwargs) -> None: + def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: """Forces all possibly joined processes to wait for each other.""" pass @@ -42,3 +42,7 @@ def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> **kwargs: plugin-specific keyword arguments """ return tensor + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py index 2acaad10c14c1..2a81d41644969 100644 --- a/pytorch_lightning/plugins/collective/torch_collective.py +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -16,22 +16,20 @@ import torch import torch.distributed +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list from pytorch_lightning.plugins.collective import Collective from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import ( - all_gather_ddp_if_available, - distributed_available, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, distributed_available +from pytorch_lightning.utilities.distributed import group as _group +from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.types import _METRIC_COLLECTION class TorchCollective(Collective): """Collective interface for DDP, DDPSpawn, DP and DDP2.""" - def __init__(self, local_reduce=False): + def __init__(self, local_reduce: bool = False, rank=None, device=None): """.. 
note:: DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False @@ -40,9 +38,11 @@ def __init__(self, local_reduce=False): local_reduce set in Plugins.setup() functions """ - self.local_reduce = local_reduce + self._local_reduce = local_reduce + self._rank = rank + self._device = device - def barrier(self, *args, **kwargs) -> None: + def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: if not distributed_available(): return if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": @@ -53,15 +53,23 @@ def barrier(self, *args, **kwargs) -> None: def broadcast(self, obj: object, src: int = 0) -> object: if not distributed_available(): return obj - return self.dist.broadcast(obj) + else: + obj = [obj] + if self.rank != 0: + obj = [None] * len(obj) + broadcast_object_list(obj, 0, group=_group.WORLD) + return obj[0] def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: """Perform a all_gather on all processes.""" return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) def reduce( - self, tensor: _METRIC_COLLECTION, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean" - ) -> torch.Tensor: + self, + tensor: Union[torch.Tensor, _METRIC_COLLECTION], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, _METRIC_COLLECTION]: """Reduces the given tensor (e.g. across GPUs/processes) If local_reduce = True (dp and ddp2), reduces tensor from all local processes. diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py index c31df54e2d1d7..96768aad8397a 100644 --- a/pytorch_lightning/plugins/collective/tpu_collective.py +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -17,9 +17,9 @@ import torch from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import _TPU_AVAILABLE if _TPU_AVAILABLE: import torch_xla.core.xla_model as xm @@ -61,7 +61,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra tensor = tensor.unsqueeze(0) return self._xm.all_gather(tensor) - def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): + def reduce(self, output: Any, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Any: if not isinstance(output, torch.Tensor): output = torch.tensor(output, device=self.lightning_module.device) @@ -78,3 +78,7 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ output = output / self.world_size return output + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision From 9766a721418dfeec293ec0acf6ccb0c705a2cd12 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Thu, 9 Sep 2021 18:28:22 -0700 Subject: [PATCH 03/11] 2/n Consolidate collective functions - collective base and subclasses --- .../plugins/collective/__init__.py | 4 ++-- .../plugins/collective/horovod_collective.py | 10 ++++---- .../collective/single_device_collective.py | 2 +- .../plugins/collective/torch_collective.py | 23 +++++++++++++------ 
.../plugins/collective/tpu_collective.py | 16 ++++++++++--- 5 files changed, 37 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/plugins/collective/__init__.py b/pytorch_lightning/plugins/collective/__init__.py index 93bf987e93652..ab5ad5c2146ed 100644 --- a/pytorch_lightning/plugins/collective/__init__.py +++ b/pytorch_lightning/plugins/collective/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.plugins.collective.collective_plugin import Collective # noqa: F401 -from pytorch_lightning.plugins.collective.single_device_collective import SingleNodeCollective # noqa: F401 -from pytorch_lightning.plugins.collective.torch_collective import TorchCollective # noqa: F401 from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective # noqa: F401 +from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective # noqa: F401 +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective # noqa: F401 from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective # noqa: F401 diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py index 19f6254964646..571f23f002960 100644 --- a/pytorch_lightning/plugins/collective/horovod_collective.py +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -31,10 +31,10 @@ class HorovodCollective(Collective): def __init__( self, on_gpu: Optional[bool] = False, - local_rank: Optional[int] = 0, + local_rank: int = 0, ): - self._on_gpu = on_gpu - self._local_rank = local_rank + self.on_gpu = on_gpu + self.local_rank = local_rank def join(self) -> None: """Horovod function that indicates that the rank finished processing data. @@ -42,8 +42,8 @@ def join(self) -> None: All ranks that did not call join() continue to process allreduce operations. This function blocks Python thread until all ranks join. """ - if self._on_gpu: - hvd.join(self._local_rank) + if self.on_gpu: + hvd.join(self.local_rank) else: hvd.join() diff --git a/pytorch_lightning/plugins/collective/single_device_collective.py b/pytorch_lightning/plugins/collective/single_device_collective.py index 7bd34abd6948d..18010f244c9a3 100644 --- a/pytorch_lightning/plugins/collective/single_device_collective.py +++ b/pytorch_lightning/plugins/collective/single_device_collective.py @@ -18,7 +18,7 @@ from pytorch_lightning.plugins.collective import Collective -class SingleNodeCollective(Collective): +class SingleDeviceCollective(Collective): """Collective interface for single device training type plugins.""" def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py index 2a81d41644969..546a107992a00 100644 --- a/pytorch_lightning/plugins/collective/torch_collective.py +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -29,7 +29,14 @@ class TorchCollective(Collective): """Collective interface for DDP, DDPSpawn, DP and DDP2.""" - def __init__(self, local_reduce: bool = False, rank=None, device=None): + def __init__( + self, + local_reduce: bool = False, + rank: Optional[int] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device_id: Optional[int] = None, + world_size: int = 1, + ): """.. 
note:: DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False @@ -38,19 +45,21 @@ def __init__(self, local_reduce: bool = False, rank=None, device=None): local_reduce set in Plugins.setup() functions """ - self._local_reduce = local_reduce - self._rank = rank - self._device = device + self.local_reduce = local_reduce + self.rank = rank + self.device = device + self.device_id = device_id + self.world_size = world_size def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: if not distributed_available(): return if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": - torch.distributed.barrier(device_ids=self.determine_ddp_device_ids()) + torch.distributed.barrier(device_ids=self.device_id) else: torch.distributed.barrier() - def broadcast(self, obj: object, src: int = 0) -> object: + def broadcast(self, obj: Any, src: int = 0) -> Any: if not distributed_available(): return obj else: @@ -97,7 +106,7 @@ def mean(t: torch.Tensor) -> torch.Tensor: return tensor def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = torch.tensor(int(decision), device=self.device) decision = self.reduce(decision, reduce_op=ReduceOp.SUM) decision = bool(decision == self.world_size) return decision diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py index 96768aad8397a..1ec2cd7517778 100644 --- a/pytorch_lightning/plugins/collective/tpu_collective.py +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -25,12 +25,22 @@ import torch_xla.core.xla_model as xm from torch_xla.core.xla_model import rendezvous else: - xm, rendezvous = [None] * 4 + xm, rendezvous = [None] * 2 class TPUCollective(Collective): """Collective interface for TPU and TPUSpawning training type plugins.""" + def __init__( + self, + device: Union[str, torch.device] = torch.device("xla"), + root_device: torch.device = xm.xla_device(), + world_size: int = xm.xrt_world_size(), + ): + self.device = device + self.root_device = root_device + self.world_size = world_size + def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) @@ -59,11 +69,11 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra """ if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: tensor = tensor.unsqueeze(0) - return self._xm.all_gather(tensor) + return xm.all_gather(tensor) def reduce(self, output: Any, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Any: if not isinstance(output, torch.Tensor): - output = torch.tensor(output, device=self.lightning_module.device) + output = torch.tensor(output, device=self.device) _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") From d2f6c53ab3578ebc93e4918a05b9424243d7aed3 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 10 Sep 2021 19:05:24 -0700 Subject: [PATCH 04/11] 2/n Consolidate collective functions - collective base and subclasses --- .../plugins/collective/tpu_collective.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py index 1ec2cd7517778..ea10c92bd1954 100644 --- 
a/pytorch_lightning/plugins/collective/tpu_collective.py +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import io +import os from typing import Any, Optional, Union import torch @@ -22,25 +23,31 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: + import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm from torch_xla.core.xla_model import rendezvous else: - xm, rendezvous = [None] * 2 + raise MisconfigurationException("TPU device does not exist in ") class TPUCollective(Collective): - """Collective interface for TPU and TPUSpawning training type plugins.""" + """Collective interface for TPUSpawning training type plugins.""" def __init__( self, device: Union[str, torch.device] = torch.device("xla"), - root_device: torch.device = xm.xla_device(), - world_size: int = xm.xrt_world_size(), + root_device: torch.device = torch.device("xla"), + world_size: int = 1, ): self.device = device self.root_device = root_device self.world_size = world_size + @property + def is_distributed(self) -> bool: + # HOST_WORLD_SIZE is None outside the xmp.spawn process + return os.getenv(xenv.HOST_WORLD_SIZE, None) is not None and self.world_size != 1 + def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) From cea07e1f1004dc5e065dc02313f78bbe0367d1bf Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 10 Sep 2021 19:30:11 -0700 Subject: [PATCH 05/11] 2/n Consolidate collective functions - collective base and subclasses --- pytorch_lightning/plugins/collective/tpu_collective.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py index ea10c92bd1954..6cacd2c688fb7 100644 --- a/pytorch_lightning/plugins/collective/tpu_collective.py +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -26,8 +26,6 @@ import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm from torch_xla.core.xla_model import rendezvous -else: - raise MisconfigurationException("TPU device does not exist in ") class TPUCollective(Collective): From c1ececd0d3eb51f776bf1fe2b3c66beb95515892 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Sun, 12 Sep 2021 14:03:22 -0700 Subject: [PATCH 06/11] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- .../plugins/collective/horovod_collective.py | 4 ++-- .../plugins/collective/torch_collective.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py index 571f23f002960..db3f29669a214 100644 --- a/pytorch_lightning/plugins/collective/horovod_collective.py +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -32,14 +32,14 @@ def __init__( self, on_gpu: Optional[bool] = False, local_rank: int = 0, - ): + ) -> None: self.on_gpu = on_gpu self.local_rank = local_rank def join(self) -> None: """Horovod function that indicates that the rank finished processing data. - All ranks that did not call join() continue to process allreduce operations. 
This function blocks Python thread + All ranks that did not call join() continue to process allreduce operations. This function blocks the Python thread until all ranks join. """ if self.on_gpu: diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py index 546a107992a00..f94389afcce1b 100644 --- a/pytorch_lightning/plugins/collective/torch_collective.py +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -36,12 +36,12 @@ def __init__( device: Optional[Union[str, torch.device]] = torch.device("cpu"), device_id: Optional[int] = None, world_size: int = 1, - ): - """.. note:: - - DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False - DP run reduce in on node, local_reduce = True - DDP2 behaves like DP in one node, local_reduce = True + ) -> None: + """ + Note: + DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False + DP run reduce in on node, local_reduce = True + DDP2 behaves like DP in one node, local_reduce = True local_reduce set in Plugins.setup() functions """ @@ -84,6 +84,7 @@ def reduce( If local_reduce = True (dp and ddp2), reduces tensor from all local processes. If local_reduce = False (ddp, ddpspawning and extentions), reduces a tensor from several distributed processes + Args: tensor: the tensor to sync and reduce group: the process group to gather results from. Defaults to all processes (world) From 97e2a8da85771e9a5800568f6d3fc54e9d55897e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 12 Sep 2021 21:04:25 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/collective/horovod_collective.py | 4 ++-- pytorch_lightning/plugins/collective/torch_collective.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py index db3f29669a214..3f1398f3be8cf 100644 --- a/pytorch_lightning/plugins/collective/horovod_collective.py +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -39,8 +39,8 @@ def __init__( def join(self) -> None: """Horovod function that indicates that the rank finished processing data. - All ranks that did not call join() continue to process allreduce operations. This function blocks the Python thread - until all ranks join. + All ranks that did not call join() continue to process allreduce operations. This function blocks the Python + thread until all ranks join. """ if self.on_gpu: hvd.join(self.local_rank) diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py index f94389afcce1b..98f8f519cdec1 100644 --- a/pytorch_lightning/plugins/collective/torch_collective.py +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -84,7 +84,7 @@ def reduce( If local_reduce = True (dp and ddp2), reduces tensor from all local processes. If local_reduce = False (ddp, ddpspawning and extentions), reduces a tensor from several distributed processes - + Args: tensor: the tensor to sync and reduce group: the process group to gather results from. 
Defaults to all processes (world) From 763352cc60f00c0ad6aad82f7430a65029d90836 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Wed, 22 Sep 2021 14:23:23 -0700 Subject: [PATCH 08/11] 2/n Consolidate collective functions - collective base and subclasses --- .../plugins/collective/__init__.py | 2 +- .../plugins/collective/collective_plugin.py | 38 +++++++++++----- .../plugins/collective/horovod_collective.py | 29 +++---------- .../collective/single_device_collective.py | 18 +------- .../plugins/collective/torch_collective.py | 43 ++++++++----------- .../plugins/collective/tpu_collective.py | 14 +++--- 6 files changed, 62 insertions(+), 82 deletions(-) diff --git a/pytorch_lightning/plugins/collective/__init__.py b/pytorch_lightning/plugins/collective/__init__.py index ab5ad5c2146ed..9b4b7c3260d7f 100644 --- a/pytorch_lightning/plugins/collective/__init__.py +++ b/pytorch_lightning/plugins/collective/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.plugins.collective.collective_plugin import Collective # noqa: F401 +from pytorch_lightning.plugins.collective.collective_plugin import CollectivePlugin # noqa: F401 from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective # noqa: F401 from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective # noqa: F401 from pytorch_lightning.plugins.collective.torch_collective import TorchCollective # noqa: F401 diff --git a/pytorch_lightning/plugins/collective/collective_plugin.py b/pytorch_lightning/plugins/collective/collective_plugin.py index c3512452883d3..192b461108082 100644 --- a/pytorch_lightning/plugins/collective/collective_plugin.py +++ b/pytorch_lightning/plugins/collective/collective_plugin.py @@ -19,22 +19,41 @@ from pytorch_lightning.utilities.distributed import ReduceOp -class Collective(ABC): - """Base class for collective functions for training type plugins.""" +class CollectivePlugin(ABC): + """Interface for collective functions for training type plugins. + + Lightning collective supports communications between multiple processes and multiple nodes, provides routines such + as barrier, broadcast, all_gather, and reduce, reduce + """ @abstractmethod def barrier(self, name: Optional[str] = None) -> None: - """Forces all possibly joined processes to wait for each other.""" + """Synchronizes all processes which blocks processes until the whole group enters this function. + + Args: + name: a str pass into barrier. Only torch xla respect this param + """ @abstractmethod def broadcast(self, obj: object, src: int = 0) -> object: - """Broadcasts an object to all processes.""" + """Broadcasts an object to all processes. + + Args: + obj: the object to broadcast + src: source rank. + """ @abstractmethod def all_gather( - self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False + self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False ) -> Union[List[torch.Tensor], torch.Tensor]: - """Perform a all_gather on all processes.""" + """Perform a all_gather on all processes. 
+ + Args: + tensor: the tensor to all_gather + process_group: the process group to gather results from + sync_grads: flag that allows users to synchronize gradients for all_gather op + """ @abstractmethod def reduce( @@ -47,10 +66,9 @@ def reduce( Args: tensor: the tensor to sync and reduce + process_group: the process group to reduce + reduce_op: the reduction operation. Defaults to 'mean'. + Can also be a string 'sum' or ReduceOp. *args: plugin-specific positional arguments **kwargs: plugin-specific keyword arguments """ - - @abstractmethod - def reduce_boolean_decision(self, decision: bool) -> bool: - """Reduce the early stopping decision across all processes.""" diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py index 3f1398f3be8cf..5e464c695b4b1 100644 --- a/pytorch_lightning/plugins/collective/horovod_collective.py +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -15,7 +15,7 @@ import torch -from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.plugins.collective import CollectivePlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.distributed import group as dist_group @@ -25,18 +25,18 @@ import horovod.torch as hvd -class HorovodCollective(Collective): +class HorovodCollective(CollectivePlugin): """Collective interface for Horovod training type plugins.""" def __init__( self, - on_gpu: Optional[bool] = False, + on_gpu: bool = False, local_rank: int = 0, ) -> None: self.on_gpu = on_gpu self.local_rank = local_rank - def join(self) -> None: + def _join(self) -> None: """Horovod function that indicates that the rank finished processing data. All ranks that did not call join() continue to process allreduce operations. This function blocks the Python @@ -49,7 +49,7 @@ def join(self) -> None: def barrier(self, *args: Any, **kwargs: Any) -> None: if distributed_available(): - self.join() + self._join() def broadcast(self, obj: object, src: int = 0) -> object: obj = hvd.broadcast_object(obj, src) @@ -66,7 +66,7 @@ def all_gather( result = result.reshape(1) # sync and gather all - self.join() + self._join() gathered = hvd.allgather(result) gathered_result = list(gathered.split(1, dim=0)) return gathered_result @@ -77,17 +77,6 @@ def reduce( group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean", ) -> Union[torch.Tensor, Any]: - """Reduces a tensor from several distributed processes to one aggregated tensor. - - Args: - tensor: the tensor to sync and reduce - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to 'mean'/'avg'. - Can also be a string 'sum' to calculate the sum during reduction. - - Return: - reduced value, except when the input was not a tensor the output remains is unchanged - """ if group is not None: raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. 
Unset `group`.") @@ -99,9 +88,5 @@ def reduce( raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") # sync all processes before reduction - self.join() + self._join() return hvd.allreduce(tensor, op=reduce_op) - - def reduce_boolean_decision(self, decision: bool) -> bool: - """Reduce the early stopping decision across all processes.""" - return decision diff --git a/pytorch_lightning/plugins/collective/single_device_collective.py b/pytorch_lightning/plugins/collective/single_device_collective.py index 18010f244c9a3..de7652a094716 100644 --- a/pytorch_lightning/plugins/collective/single_device_collective.py +++ b/pytorch_lightning/plugins/collective/single_device_collective.py @@ -15,34 +15,20 @@ import torch -from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.plugins.collective import CollectivePlugin -class SingleDeviceCollective(Collective): +class SingleDeviceCollective(CollectivePlugin): """Collective interface for single device training type plugins.""" def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: - """Forces all possibly joined processes to wait for each other.""" pass def broadcast(self, obj: object, src: int = 0) -> object: - """Broadcasts an object to all processes.""" return obj def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes.""" return tensor def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: - """Reduces the given tensor (e.g. across GPUs/processes). - - Args: - tensor: the tensor to sync and reduce - *args: plugin-specific positional arguments - **kwargs: plugin-specific keyword arguments - """ return tensor - - def reduce_boolean_decision(self, decision: bool) -> bool: - """Reduce the early stopping decision across all processes.""" - return decision diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py index 98f8f519cdec1..5f499418c02c0 100644 --- a/pytorch_lightning/plugins/collective/torch_collective.py +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -17,7 +17,7 @@ import torch.distributed from pytorch_lightning.overrides.torch_distributed import broadcast_object_list -from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.plugins.collective import CollectivePlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, distributed_available @@ -26,8 +26,11 @@ from pytorch_lightning.utilities.types import _METRIC_COLLECTION -class TorchCollective(Collective): - """Collective interface for DDP, DDPSpawn, DP and DDP2.""" +class TorchCollective(CollectivePlugin): + """Collective interfaces for PyTorch. + + Mainly used by DDP, DDPSpawn, DP and DDP2. 
+ """ def __init__( self, @@ -35,7 +38,6 @@ def __init__( rank: Optional[int] = None, device: Optional[Union[str, torch.device]] = torch.device("cpu"), device_id: Optional[int] = None, - world_size: int = 1, ) -> None: """ Note: @@ -49,9 +51,8 @@ def __init__( self.rank = rank self.device = device self.device_id = device_id - self.world_size = world_size - def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: + def barrier(self, *args: Any, **kwargs: Any) -> None: if not distributed_available(): return if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": @@ -62,16 +63,16 @@ def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None def broadcast(self, obj: Any, src: int = 0) -> Any: if not distributed_available(): return obj - else: - obj = [obj] - if self.rank != 0: - obj = [None] * len(obj) - broadcast_object_list(obj, 0, group=_group.WORLD) - return obj[0] + obj = [obj] + if self.rank != 0: + obj = [None] * len(obj) + broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes.""" - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + def all_gather( + self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False + ) -> torch.Tensor: + return all_gather_ddp_if_available(tensor, group=process_group, sync_grads=sync_grads) def reduce( self, @@ -87,9 +88,9 @@ def reduce( Args: tensor: the tensor to sync and reduce - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to 'mean'/'avg'. - Can also be a string 'sum' to calculate the sum during reduction. + process_group: the process group to reduce + reduce_op: the reduction operation. Defaults to 'mean'. + Can also be a string 'sum' or ReduceOp. 
Return: reduced value, except when the input was not a tensor the output remains is unchanged @@ -105,9 +106,3 @@ def mean(t: torch.Tensor) -> torch.Tensor: if isinstance(tensor, torch.Tensor): tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor - - def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.device) - decision = self.reduce(decision, reduce_op=ReduceOp.SUM) - decision = bool(decision == self.world_size) - return decision diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py index 6cacd2c688fb7..b90decfdc16e4 100644 --- a/pytorch_lightning/plugins/collective/tpu_collective.py +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -17,7 +17,7 @@ import torch -from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.plugins.collective import CollectivePlugin from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -28,7 +28,7 @@ from torch_xla.core.xla_model import rendezvous -class TPUCollective(Collective): +class TPUCollective(CollectivePlugin): """Collective interface for TPUSpawning training type plugins.""" def __init__( @@ -42,16 +42,16 @@ def __init__( self.world_size = world_size @property - def is_distributed(self) -> bool: + def _is_distributed(self) -> bool: # HOST_WORLD_SIZE is None outside the xmp.spawn process return os.getenv(xenv.HOST_WORLD_SIZE, None) is not None and self.world_size != 1 def barrier(self, name: Optional[str] = None) -> None: - if self.is_distributed: + if self._is_distributed: rendezvous(name) def broadcast(self, obj: object, src: int = 0) -> object: - if not self.is_distributed: + if not self._is_distributed: return obj buffer = io.BytesIO() torch.save(obj, buffer) @@ -93,7 +93,3 @@ def reduce(self, output: Any, group: Optional[Any] = None, reduce_op: Optional[U output = output / self.world_size return output - - def reduce_boolean_decision(self, decision: bool) -> bool: - """Reduce the early stopping decision across all processes.""" - return decision From 393adcc2440e4d16f9d2594fbcd268e2df585aa2 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Wed, 22 Sep 2021 14:50:14 -0700 Subject: [PATCH 09/11] 2/n Consolidate collective functions - collective base and subclasses --- .../plugins/collective/collective_plugin.py | 2 +- .../plugins/collective/horovod_collective.py | 12 +++++++----- .../plugins/collective/single_device_collective.py | 4 +++- .../plugins/collective/torch_collective.py | 8 ++++---- .../plugins/collective/tpu_collective.py | 10 +++++++--- 5 files changed, 22 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/plugins/collective/collective_plugin.py b/pytorch_lightning/plugins/collective/collective_plugin.py index 192b461108082..4672f9b767e61 100644 --- a/pytorch_lightning/plugins/collective/collective_plugin.py +++ b/pytorch_lightning/plugins/collective/collective_plugin.py @@ -59,7 +59,7 @@ def all_gather( def reduce( self, tensor: Union[torch.Tensor, Any], - group: Optional[Any] = None, + process_group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean", ) -> Union[torch.Tensor, Any]: """Reduces the given tensor (e.g. across GPUs/processes). 
diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py
index 5e464c695b4b1..77755a2addb12 100644
--- a/pytorch_lightning/plugins/collective/horovod_collective.py
+++ b/pytorch_lightning/plugins/collective/horovod_collective.py
@@ -56,9 +56,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         return obj
 
     def all_gather(
-        self, result: Union[torch.Tensor], group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
+        self, result: Union[torch.Tensor], process_group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
     ) -> List[torch.Tensor]:
-        if group is not None and group != dist_group.WORLD:
+        if process_group is not None and process_group != dist_group.WORLD:
             raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")
 
         if len(result.shape) == 0:
@@ -74,11 +74,13 @@ def all_gather(
     def reduce(
         self,
         tensor: Union[torch.Tensor, Any],
-        group: Optional[Any] = None,
+        process_group: Optional[Any] = None,
         reduce_op: Optional[Union[ReduceOp, str]] = "mean",
     ) -> Union[torch.Tensor, Any]:
-        if group is not None:
-            raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. Unset `group`.")
+        if process_group is not None:
+            raise ValueError(
+                "Horovod does not support allreduce using a subcommunicator at this time. Unset `process_group`."
+            )
 
         if reduce_op in (None, "avg", "mean"):
             reduce_op = hvd.Average
diff --git a/pytorch_lightning/plugins/collective/single_device_collective.py b/pytorch_lightning/plugins/collective/single_device_collective.py
index de7652a094716..4c8543663b4c0 100644
--- a/pytorch_lightning/plugins/collective/single_device_collective.py
+++ b/pytorch_lightning/plugins/collective/single_device_collective.py
@@ -27,7 +27,9 @@ def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None
     def broadcast(self, obj: object, src: int = 0) -> object:
         return obj
 
-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+    def all_gather(
+        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
+    ) -> torch.Tensor:
         return tensor
 
     def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py
index 5f499418c02c0..83b17fe02ecd9 100644
--- a/pytorch_lightning/plugins/collective/torch_collective.py
+++ b/pytorch_lightning/plugins/collective/torch_collective.py
@@ -21,7 +21,7 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, distributed_available
-from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import group as dist_group
 from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION
 
@@ -66,7 +66,7 @@ def broadcast(self, obj: Any, src: int = 0) -> Any:
         obj = [obj]
         if self.rank != 0:
             obj = [None] * len(obj)
-        broadcast_object_list(obj, src, group=_group.WORLD)
+        broadcast_object_list(obj, src, group=dist_group.WORLD)
         return obj[0]
 
     def all_gather(
@@ -77,7 +77,7 @@ def all_gather(
     def reduce(
         self,
         tensor: Union[torch.Tensor, _METRIC_COLLECTION],
-        group: Optional[Any] = None,
+        process_group: Optional[Any] = None,
         reduce_op: Optional[Union[ReduceOp, str]] = "mean",
     ) -> Union[torch.Tensor, _METRIC_COLLECTION]:
         """Reduces the given tensor (e.g. across GPUs/processes)
@@ -104,5 +104,5 @@ def mean(t: torch.Tensor) -> torch.Tensor:
             return apply_to_collection(tensor, torch.Tensor, mean)
 
         if isinstance(tensor, torch.Tensor):
-            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
+            tensor = sync_ddp_if_available(tensor, process_group, reduce_op=reduce_op)
         return tensor
diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py
index b90decfdc16e4..4b5238758b035 100644
--- a/pytorch_lightning/plugins/collective/tpu_collective.py
+++ b/pytorch_lightning/plugins/collective/tpu_collective.py
@@ -62,12 +62,14 @@ def broadcast(self, obj: object, src: int = 0) -> object:
             obj = torch.load(buffer)
         return obj
 
-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+    def all_gather(
+        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
+    ) -> torch.Tensor:
         """
         Function to gather a tensor from several distributed processes
         Args:
             tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
+            process_group: not available with TPUs
             sync_grads: not available with TPUs
         Return:
             A tensor of shape (world_size, batch, ...)
@@ -76,7 +78,9 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
             tensor = tensor.unsqueeze(0)
         return xm.all_gather(tensor)
 
-    def reduce(self, output: Any, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Any:
+    def reduce(
+        self, output: Any, process_group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    ) -> Any:
         if not isinstance(output, torch.Tensor):
             output = torch.tensor(output, device=self.device)

From 5f7febf4172a7195e6ef8b0c6887571cd02e5904 Mon Sep 17 00:00:00 2001
From: Siyu Wang
Date: Wed, 22 Sep 2021 18:59:09 -0700
Subject: [PATCH 10/11] 2/n Consolidate collective functions - collective base and subclasses

---
 .../plugins/collective/collective_plugin.py  | 9 +++++++--
 .../plugins/collective/horovod_collective.py | 3 +--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/plugins/collective/collective_plugin.py b/pytorch_lightning/plugins/collective/collective_plugin.py
index 4672f9b767e61..2a2d6f0020144 100644
--- a/pytorch_lightning/plugins/collective/collective_plugin.py
+++ b/pytorch_lightning/plugins/collective/collective_plugin.py
@@ -20,10 +20,13 @@ class CollectivePlugin(ABC):
-    """Interface for collective functions for training type plugins.
+    """Interface for collective functions.
 
     Lightning collective supports communications between multiple processes and multiple nodes, provides routines such
-    as barrier, broadcast, all_gather, and reduce, reduce
+    as barrier, broadcast, all_gather, and reduce
+
+    .. note::
+        This API is experimental/in-beta and subject to change
     """
 
     @abstractmethod
@@ -53,6 +56,8 @@ def all_gather(
             tensor: the tensor to all_gather
             process_group: the process group to gather results from
             sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Returns: a tensor (torch distributed) or a list of tensors (horovod)
         """
 
     @abstractmethod
diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py
index 77755a2addb12..0ed5924809ca2 100644
--- a/pytorch_lightning/plugins/collective/horovod_collective.py
+++ b/pytorch_lightning/plugins/collective/horovod_collective.py
@@ -52,8 +52,7 @@ def barrier(self, *args: Any, **kwargs: Any) -> None:
             self._join()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
-        obj = hvd.broadcast_object(obj, src)
-        return obj
+        return hvd.broadcast_object(obj, src)
 
     def all_gather(
         self, result: Union[torch.Tensor], process_group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False

From 464fbf1eb104cac67b8a5f26b296e93d1859df3e Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Mon, 6 Dec 2021 14:40:46 -0800
Subject: [PATCH 11/11] Update CHANGELOG.md

---
 CHANGELOG.md | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0a8e6f1715428..0a168db7b486d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -120,14 +120,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed deprecated `TrainerModelHooksMixin.is_function_implemented` and `TrainerModelHooksMixin.has_arg` ([#10322](https://github.com/PyTorchLightning/pytorch-lightning/pull/10322))
 
-<<<<<<< HEAD
 - Removed deprecated `pytorch_lightning.utilities.device_dtype_mixin.DeviceDtypeModuleMixin` in favor of `pytorch_lightning.core.mixins.device_dtype_mixin.DeviceDtypeModuleMixin` ([#10442](https://github.com/PyTorchLightning/pytorch-lightning/pull/10442))
-=======
-- Add collective base class and subclasses ([#9414](https://github.com/PyTorchLightning/pytorch-lightning/pull/9414))
-
-
-### Changed
->>>>>>> f12625596 (2/n Consolidate collective functions - collective base and subclasses)
 
 - Removed deprecated `LightningModule.loaded_optimizer_states_dict` property ([#10346](https://github.com/PyTorchLightning/pytorch-lightning/pull/10346))
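
Usage sketch (illustration only, not part of the patches above): the snippet below shows how the consolidated collective API introduced by this series could be exercised on a single process. The import path, constructor arguments, and method signatures follow the diffs in this series; the standalone wiring, sample tensors, and variable names are hypothetical — inside Lightning the collective instance would normally be created and called by a training type plugin rather than by user code.

import torch

from pytorch_lightning.plugins.collective import TorchCollective

# Each process owns one collective; rank/device would normally be supplied by
# the training type plugin that constructs it.
collective = TorchCollective(rank=0, device=torch.device("cpu"))

# Wait for all processes (returns immediately when torch.distributed is not initialized).
collective.barrier()

# Rank 0 decides on a value; every rank receives the same object back.
run_name = collective.broadcast("run-0", src=0)

# Average a per-process scalar across workers (returned unchanged on a single process).
loss = torch.tensor(0.25)
loss = collective.reduce(loss, reduce_op="mean")

# Gather a tensor from every process (stacked along a new world_size dimension
# when distributed; returned unchanged otherwise).
gathered = collective.all_gather(torch.zeros(2), sync_grads=False)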