From ed43ad60a64a830bf7230bcb81b9e29b27279ccd Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Thu, 9 Sep 2021 11:50:12 -0700 Subject: [PATCH 1/8] 2/n Consolidate collective functions - collective base and subclasses --- CHANGELOG.md | 3 + .../plugins/collective/__init__.py | 18 +++ .../plugins/collective/collective_plugin.py | 47 ++++++++ .../plugins/collective/horovod_collective.py | 110 ++++++++++++++++++ .../collective/single_device_collective.py | 44 +++++++ .../plugins/collective/torch_collective.py | 95 +++++++++++++++ .../plugins/collective/tpu_collective.py | 80 +++++++++++++ 7 files changed, 397 insertions(+) create mode 100644 pytorch_lightning/plugins/collective/__init__.py create mode 100644 pytorch_lightning/plugins/collective/collective_plugin.py create mode 100644 pytorch_lightning/plugins/collective/horovod_collective.py create mode 100644 pytorch_lightning/plugins/collective/single_device_collective.py create mode 100644 pytorch_lightning/plugins/collective/torch_collective.py create mode 100644 pytorch_lightning/plugins/collective/tpu_collective.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a8b068e5eccdc..453822e27d4b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -120,6 +120,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344)) +- Add collective base class and subclasses ([#9414](https://github.com/PyTorchLightning/pytorch-lightning/pull/9414)) + + ### Changed - `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)). diff --git a/pytorch_lightning/plugins/collective/__init__.py b/pytorch_lightning/plugins/collective/__init__.py new file mode 100644 index 0000000000000..93bf987e93652 --- /dev/null +++ b/pytorch_lightning/plugins/collective/__init__.py @@ -0,0 +1,18 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.plugins.collective.collective_plugin import Collective # noqa: F401 +from pytorch_lightning.plugins.collective.single_device_collective import SingleNodeCollective # noqa: F401 +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective # noqa: F401 +from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective # noqa: F401 +from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective # noqa: F401 diff --git a/pytorch_lightning/plugins/collective/collective_plugin.py b/pytorch_lightning/plugins/collective/collective_plugin.py new file mode 100644 index 0000000000000..0a63f48f74da0 --- /dev/null +++ b/pytorch_lightning/plugins/collective/collective_plugin.py @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Optional, Union + +import torch + + +class Collective(ABC): + """Base class for collective functions for training type plugins.""" + + @abstractmethod + def barrier(self, name: Optional[str] = None, *args, **kwargs) -> None: + """Forces all possibly joined processes to wait for each other.""" + + @abstractmethod + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes.""" + + @abstractmethod + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes.""" + + @abstractmethod + def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + """Reduces the given tensor (e.g. across GPUs/processes). + + Args: + tensor: the tensor to sync and reduce + *args: plugin-specific positional arguments + **kwargs: plugin-specific keyword arguments + """ + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py new file mode 100644 index 0000000000000..bc9c4bdcd5bf7 --- /dev/null +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -0,0 +1,110 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +from typing import Any, Optional, Union + +import torch + +from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE +from pytorch_lightning.utilities.distributed import ReduceOp +from pytorch_lightning.utilities.types import _TPU_AVAILABLE + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + from torch_xla.core.xla_model import rendezvous +else: + xm, rendezvous = [None] * 4 + +if _HOROVOD_AVAILABLE: + import horovod.torch as hvd + + +class HorovodCollective(Collective): + """Collective interface for Horovod training type plugins.""" + + def __init__( + self, + on_gpu: Optional[bool] = False, + local_rank: Optional[int] = 0, + ): + self._on_gpu = on_gpu + self._local_rank = local_rank + + def join(self): + """Horovod function that indicates that the rank finished processing data. + + All ranks that did not call join() continue to process allreduce operations. This function blocks Python thread + until all ranks join. + """ + if self.on_gpu: + hvd.join(self.local_rank) + else: + hvd.join() + + def barrier(self, name: Optional[str] = None) -> None: + if self.is_distributed: + rendezvous(name) + + def broadcast(self, obj: object, src: int = 0) -> object: + if not self.is_distributed: + return obj + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) + data = xm.all_gather(data_tensor) + buffer = io.BytesIO(data.cpu().byte().numpy()) + obj = torch.load(buffer) + return obj + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + Args: + tensor: tensor of shape (batch, ...) + group: not available with TPUs + sync_grads: not available with TPUs + Return: + A tensor of shape (world_size, batch, ...) + """ + if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + return self._xm.all_gather(tensor) + + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + """Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if group is not None: + raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. Unset `group`.") + + if reduce_op in (None, "avg", "mean"): + reduce_op = hvd.Average + elif reduce_op in ("sum", ReduceOp.SUM): + reduce_op = hvd.Sum + else: + raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") + + # sync all processes before reduction + self.join() + return hvd.allreduce(tensor, op=reduce_op) diff --git a/pytorch_lightning/plugins/collective/single_device_collective.py b/pytorch_lightning/plugins/collective/single_device_collective.py new file mode 100644 index 0000000000000..84a4c44149869 --- /dev/null +++ b/pytorch_lightning/plugins/collective/single_device_collective.py @@ -0,0 +1,44 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union + +import torch + +from pytorch_lightning.plugins.collective import Collective + + +class SingleNodeCollective(Collective): + """Collective interface for single device training type plugins.""" + + def barrier(self, name: Optional[str] = None, *args, **kwargs) -> None: + """Forces all possibly joined processes to wait for each other.""" + pass + + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes.""" + return obj + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes.""" + return tensor + + def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + """Reduces the given tensor (e.g. across GPUs/processes). + + Args: + tensor: the tensor to sync and reduce + *args: plugin-specific positional arguments + **kwargs: plugin-specific keyword arguments + """ + return tensor diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py new file mode 100644 index 0000000000000..2acaad10c14c1 --- /dev/null +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -0,0 +1,95 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union + +import torch +import torch.distributed + +from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.distributed import ( + all_gather_ddp_if_available, + distributed_available, + ReduceOp, + sync_ddp_if_available, +) +from pytorch_lightning.utilities.types import _METRIC_COLLECTION + + +class TorchCollective(Collective): + """Collective interface for DDP, DDPSpawn, DP and DDP2.""" + + def __init__(self, local_reduce=False): + """.. note:: + + DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False + DP run reduce in on node, local_reduce = True + DDP2 behaves like DP in one node, local_reduce = True + + local_reduce set in Plugins.setup() functions + """ + self.local_reduce = local_reduce + + def barrier(self, *args, **kwargs) -> None: + if not distributed_available(): + return + if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=self.determine_ddp_device_ids()) + else: + torch.distributed.barrier() + + def broadcast(self, obj: object, src: int = 0) -> object: + if not distributed_available(): + return obj + return self.dist.broadcast(obj) + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes.""" + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + + def reduce( + self, tensor: _METRIC_COLLECTION, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean" + ) -> torch.Tensor: + """Reduces the given tensor (e.g. across GPUs/processes) + + If local_reduce = True (dp and ddp2), reduces tensor from all local processes. + + If local_reduce = False (ddp, ddpspawning and extentions), reduces a tensor from several distributed processes + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if self.local_reduce: + + def mean(t: torch.Tensor) -> torch.Tensor: + original_dtype = t.dtype + return t.float().mean().to(original_dtype) + + return apply_to_collection(tensor, torch.Tensor, mean) + + if isinstance(tensor, torch.Tensor): + tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + def reduce_boolean_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = self.reduce(decision, reduce_op=ReduceOp.SUM) + decision = bool(decision == self.world_size) + return decision diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py new file mode 100644 index 0000000000000..c31df54e2d1d7 --- /dev/null +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -0,0 +1,80 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +from typing import Any, Optional, Union + +import torch + +from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities.distributed import ReduceOp +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _TPU_AVAILABLE + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + from torch_xla.core.xla_model import rendezvous +else: + xm, rendezvous = [None] * 4 + + +class TPUCollective(Collective): + """Collective interface for TPU and TPUSpawning training type plugins.""" + + def barrier(self, name: Optional[str] = None) -> None: + if self.is_distributed: + rendezvous(name) + + def broadcast(self, obj: object, src: int = 0) -> object: + if not self.is_distributed: + return obj + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) + data = xm.all_gather(data_tensor) + buffer = io.BytesIO(data.cpu().byte().numpy()) + obj = torch.load(buffer) + return obj + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + Args: + tensor: tensor of shape (batch, ...) + group: not available with TPUs + sync_grads: not available with TPUs + Return: + A tensor of shape (world_size, batch, ...) + """ + if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + return self._xm.all_gather(tensor) + + def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): + if not isinstance(output, torch.Tensor): + output = torch.tensor(output, device=self.lightning_module.device) + + _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM + _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") + if _invalid_reduce_op or _invalid_reduce_op_str: + raise MisconfigurationException( + "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation." + ) + + output = xm.mesh_reduce("reduce", output, sum) + + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + output = output / self.world_size + + return output From 387432a4c93bf01daec05ad22d40b79322721b8c Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Thu, 9 Sep 2021 18:06:38 -0700 Subject: [PATCH 2/8] 2/n Consolidate collective functions - collective base and subclasses --- .../plugins/collective/collective_plugin.py | 19 +++-- .../plugins/collective/horovod_collective.py | 73 +++++++++---------- .../collective/single_device_collective.py | 6 +- .../plugins/collective/torch_collective.py | 32 +++++--- .../plugins/collective/tpu_collective.py | 8 +- 5 files changed, 80 insertions(+), 58 deletions(-) diff --git a/pytorch_lightning/plugins/collective/collective_plugin.py b/pytorch_lightning/plugins/collective/collective_plugin.py index 0a63f48f74da0..c3512452883d3 100644 --- a/pytorch_lightning/plugins/collective/collective_plugin.py +++ b/pytorch_lightning/plugins/collective/collective_plugin.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch +from pytorch_lightning.utilities.distributed import ReduceOp + class Collective(ABC): """Base class for collective functions for training type plugins.""" @abstractmethod - def barrier(self, name: Optional[str] = None, *args, **kwargs) -> None: + def barrier(self, name: Optional[str] = None) -> None: """Forces all possibly joined processes to wait for each other.""" @abstractmethod @@ -29,11 +31,18 @@ def broadcast(self, obj: object, src: int = 0) -> object: """Broadcasts an object to all processes.""" @abstractmethod - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + def all_gather( + self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[List[torch.Tensor], torch.Tensor]: """Perform a all_gather on all processes.""" @abstractmethod - def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + def reduce( + self, + tensor: Union[torch.Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, Any]: """Reduces the given tensor (e.g. across GPUs/processes). Args: @@ -42,6 +51,6 @@ def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> **kwargs: plugin-specific keyword arguments """ + @abstractmethod def reduce_boolean_decision(self, decision: bool) -> bool: """Reduce the early stopping decision across all processes.""" - return decision diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py index bc9c4bdcd5bf7..19f6254964646 100644 --- a/pytorch_lightning/plugins/collective/horovod_collective.py +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -11,21 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from pytorch_lightning.plugins.collective import Collective from pytorch_lightning.utilities import _HOROVOD_AVAILABLE +from pytorch_lightning.utilities.distributed import distributed_available +from pytorch_lightning.utilities.distributed import group as dist_group from pytorch_lightning.utilities.distributed import ReduceOp -from pytorch_lightning.utilities.types import _TPU_AVAILABLE - -if _TPU_AVAILABLE: - import torch_xla.core.xla_model as xm - from torch_xla.core.xla_model import rendezvous -else: - xm, rendezvous = [None] * 4 if _HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -42,48 +36,47 @@ def __init__( self._on_gpu = on_gpu self._local_rank = local_rank - def join(self): + def join(self) -> None: """Horovod function that indicates that the rank finished processing data. All ranks that did not call join() continue to process allreduce operations. This function blocks Python thread until all ranks join. """ - if self.on_gpu: - hvd.join(self.local_rank) + if self._on_gpu: + hvd.join(self._local_rank) else: hvd.join() - def barrier(self, name: Optional[str] = None) -> None: - if self.is_distributed: - rendezvous(name) + def barrier(self, *args: Any, **kwargs: Any) -> None: + if distributed_available(): + self.join() def broadcast(self, obj: object, src: int = 0) -> object: - if not self.is_distributed: - return obj - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) - data = xm.all_gather(data_tensor) - buffer = io.BytesIO(data.cpu().byte().numpy()) - obj = torch.load(buffer) + obj = hvd.broadcast_object(obj, src) return obj - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes - Args: - tensor: tensor of shape (batch, ...) - group: not available with TPUs - sync_grads: not available with TPUs - Return: - A tensor of shape (world_size, batch, ...) - """ - if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: - tensor = tensor.unsqueeze(0) - return self._xm.all_gather(tensor) + def all_gather( + self, result: Union[torch.Tensor], group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False + ) -> List[torch.Tensor]: + if group is not None and group != dist_group.WORLD: + raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.") - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + if len(result.shape) == 0: + # Convert scalars to single dimension tensors + result = result.reshape(1) + + # sync and gather all + self.join() + gathered = hvd.allgather(result) + gathered_result = list(gathered.split(1, dim=0)) + return gathered_result + + def reduce( + self, + tensor: Union[torch.Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, Any]: """Reduces a tensor from several distributed processes to one aggregated tensor. Args: @@ -108,3 +101,7 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ # sync all processes before reduction self.join() return hvd.allreduce(tensor, op=reduce_op) + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision diff --git a/pytorch_lightning/plugins/collective/single_device_collective.py b/pytorch_lightning/plugins/collective/single_device_collective.py index 84a4c44149869..7bd34abd6948d 100644 --- a/pytorch_lightning/plugins/collective/single_device_collective.py +++ b/pytorch_lightning/plugins/collective/single_device_collective.py @@ -21,7 +21,7 @@ class SingleNodeCollective(Collective): """Collective interface for single device training type plugins.""" - def barrier(self, name: Optional[str] = None, *args, **kwargs) -> None: + def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: """Forces all possibly joined processes to wait for each other.""" pass @@ -42,3 +42,7 @@ def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> **kwargs: plugin-specific keyword arguments """ return tensor + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py index 2acaad10c14c1..2a81d41644969 100644 --- a/pytorch_lightning/plugins/collective/torch_collective.py +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -16,22 +16,20 @@ import torch import torch.distributed +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list from pytorch_lightning.plugins.collective import Collective from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import ( - all_gather_ddp_if_available, - distributed_available, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, distributed_available +from pytorch_lightning.utilities.distributed import group as _group +from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.types import _METRIC_COLLECTION class TorchCollective(Collective): """Collective interface for DDP, DDPSpawn, DP and DDP2.""" - def __init__(self, local_reduce=False): + def __init__(self, local_reduce: bool = False, rank=None, device=None): """.. note:: DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False @@ -40,9 +38,11 @@ def __init__(self, local_reduce=False): local_reduce set in Plugins.setup() functions """ - self.local_reduce = local_reduce + self._local_reduce = local_reduce + self._rank = rank + self._device = device - def barrier(self, *args, **kwargs) -> None: + def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: if not distributed_available(): return if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": @@ -53,15 +53,23 @@ def barrier(self, *args, **kwargs) -> None: def broadcast(self, obj: object, src: int = 0) -> object: if not distributed_available(): return obj - return self.dist.broadcast(obj) + else: + obj = [obj] + if self.rank != 0: + obj = [None] * len(obj) + broadcast_object_list(obj, 0, group=_group.WORLD) + return obj[0] def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: """Perform a all_gather on all processes.""" return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) def reduce( - self, tensor: _METRIC_COLLECTION, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean" - ) -> torch.Tensor: + self, + tensor: Union[torch.Tensor, _METRIC_COLLECTION], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, _METRIC_COLLECTION]: """Reduces the given tensor (e.g. across GPUs/processes) If local_reduce = True (dp and ddp2), reduces tensor from all local processes. diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py index c31df54e2d1d7..96768aad8397a 100644 --- a/pytorch_lightning/plugins/collective/tpu_collective.py +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -17,9 +17,9 @@ import torch from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import _TPU_AVAILABLE if _TPU_AVAILABLE: import torch_xla.core.xla_model as xm @@ -61,7 +61,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra tensor = tensor.unsqueeze(0) return self._xm.all_gather(tensor) - def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): + def reduce(self, output: Any, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Any: if not isinstance(output, torch.Tensor): output = torch.tensor(output, device=self.lightning_module.device) @@ -78,3 +78,7 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ output = output / self.world_size return output + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision From 9f26bf563c11bc5d9d7d98c763fa09617902f82d Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Thu, 9 Sep 2021 18:28:22 -0700 Subject: [PATCH 3/8] 2/n Consolidate collective functions - collective base and subclasses --- .../plugins/collective/__init__.py | 4 ++-- .../plugins/collective/horovod_collective.py | 10 ++++---- .../collective/single_device_collective.py | 2 +- .../plugins/collective/torch_collective.py | 23 +++++++++++++------ .../plugins/collective/tpu_collective.py | 16 ++++++++++--- 5 files changed, 37 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/plugins/collective/__init__.py b/pytorch_lightning/plugins/collective/__init__.py index 93bf987e93652..ab5ad5c2146ed 100644 --- a/pytorch_lightning/plugins/collective/__init__.py +++ b/pytorch_lightning/plugins/collective/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.plugins.collective.collective_plugin import Collective # noqa: F401 -from pytorch_lightning.plugins.collective.single_device_collective import SingleNodeCollective # noqa: F401 -from pytorch_lightning.plugins.collective.torch_collective import TorchCollective # noqa: F401 from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective # noqa: F401 +from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective # noqa: F401 +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective # noqa: F401 from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective # noqa: F401 diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py index 19f6254964646..571f23f002960 100644 --- a/pytorch_lightning/plugins/collective/horovod_collective.py +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -31,10 +31,10 @@ class HorovodCollective(Collective): def __init__( self, on_gpu: Optional[bool] = False, - local_rank: Optional[int] = 0, + local_rank: int = 0, ): - self._on_gpu = on_gpu - self._local_rank = local_rank + self.on_gpu = on_gpu + self.local_rank = local_rank def join(self) -> None: """Horovod function that indicates that the rank finished processing data. @@ -42,8 +42,8 @@ def join(self) -> None: All ranks that did not call join() continue to process allreduce operations. This function blocks Python thread until all ranks join. """ - if self._on_gpu: - hvd.join(self._local_rank) + if self.on_gpu: + hvd.join(self.local_rank) else: hvd.join() diff --git a/pytorch_lightning/plugins/collective/single_device_collective.py b/pytorch_lightning/plugins/collective/single_device_collective.py index 7bd34abd6948d..18010f244c9a3 100644 --- a/pytorch_lightning/plugins/collective/single_device_collective.py +++ b/pytorch_lightning/plugins/collective/single_device_collective.py @@ -18,7 +18,7 @@ from pytorch_lightning.plugins.collective import Collective -class SingleNodeCollective(Collective): +class SingleDeviceCollective(Collective): """Collective interface for single device training type plugins.""" def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py index 2a81d41644969..546a107992a00 100644 --- a/pytorch_lightning/plugins/collective/torch_collective.py +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -29,7 +29,14 @@ class TorchCollective(Collective): """Collective interface for DDP, DDPSpawn, DP and DDP2.""" - def __init__(self, local_reduce: bool = False, rank=None, device=None): + def __init__( + self, + local_reduce: bool = False, + rank: Optional[int] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device_id: Optional[int] = None, + world_size: int = 1, + ): """.. note:: DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False @@ -38,19 +45,21 @@ def __init__(self, local_reduce: bool = False, rank=None, device=None): local_reduce set in Plugins.setup() functions """ - self._local_reduce = local_reduce - self._rank = rank - self._device = device + self.local_reduce = local_reduce + self.rank = rank + self.device = device + self.device_id = device_id + self.world_size = world_size def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: if not distributed_available(): return if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": - torch.distributed.barrier(device_ids=self.determine_ddp_device_ids()) + torch.distributed.barrier(device_ids=self.device_id) else: torch.distributed.barrier() - def broadcast(self, obj: object, src: int = 0) -> object: + def broadcast(self, obj: Any, src: int = 0) -> Any: if not distributed_available(): return obj else: @@ -97,7 +106,7 @@ def mean(t: torch.Tensor) -> torch.Tensor: return tensor def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = torch.tensor(int(decision), device=self.device) decision = self.reduce(decision, reduce_op=ReduceOp.SUM) decision = bool(decision == self.world_size) return decision diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py index 96768aad8397a..1ec2cd7517778 100644 --- a/pytorch_lightning/plugins/collective/tpu_collective.py +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -25,12 +25,22 @@ import torch_xla.core.xla_model as xm from torch_xla.core.xla_model import rendezvous else: - xm, rendezvous = [None] * 4 + xm, rendezvous = [None] * 2 class TPUCollective(Collective): """Collective interface for TPU and TPUSpawning training type plugins.""" + def __init__( + self, + device: Union[str, torch.device] = torch.device("xla"), + root_device: torch.device = xm.xla_device(), + world_size: int = xm.xrt_world_size(), + ): + self.device = device + self.root_device = root_device + self.world_size = world_size + def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) @@ -59,11 +69,11 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra """ if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: tensor = tensor.unsqueeze(0) - return self._xm.all_gather(tensor) + return xm.all_gather(tensor) def reduce(self, output: Any, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Any: if not isinstance(output, torch.Tensor): - output = torch.tensor(output, device=self.lightning_module.device) + output = torch.tensor(output, device=self.device) _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") From 218188c8c2762e386e7a1c6c2b1b10d1e0a541c0 Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 10 Sep 2021 19:05:24 -0700 Subject: [PATCH 4/8] 2/n Consolidate collective functions - collective base and subclasses --- .../plugins/collective/tpu_collective.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py index 1ec2cd7517778..ea10c92bd1954 100644 --- a/pytorch_lightning/plugins/collective/tpu_collective.py +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import io +import os from typing import Any, Optional, Union import torch @@ -22,25 +23,31 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: + import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm from torch_xla.core.xla_model import rendezvous else: - xm, rendezvous = [None] * 2 + raise MisconfigurationException("TPU device does not exist in ") class TPUCollective(Collective): - """Collective interface for TPU and TPUSpawning training type plugins.""" + """Collective interface for TPUSpawning training type plugins.""" def __init__( self, device: Union[str, torch.device] = torch.device("xla"), - root_device: torch.device = xm.xla_device(), - world_size: int = xm.xrt_world_size(), + root_device: torch.device = torch.device("xla"), + world_size: int = 1, ): self.device = device self.root_device = root_device self.world_size = world_size + @property + def is_distributed(self) -> bool: + # HOST_WORLD_SIZE is None outside the xmp.spawn process + return os.getenv(xenv.HOST_WORLD_SIZE, None) is not None and self.world_size != 1 + def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) From e4ac51164aaa4c3e2f1465e59b5120432b22e82a Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Fri, 10 Sep 2021 19:30:11 -0700 Subject: [PATCH 5/8] 2/n Consolidate collective functions - collective base and subclasses --- pytorch_lightning/plugins/collective/tpu_collective.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py index ea10c92bd1954..6cacd2c688fb7 100644 --- a/pytorch_lightning/plugins/collective/tpu_collective.py +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -26,8 +26,6 @@ import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm from torch_xla.core.xla_model import rendezvous -else: - raise MisconfigurationException("TPU device does not exist in ") class TPUCollective(Collective): From 8574b007de31635c1b9fd77c536642b7a412ef70 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Sun, 12 Sep 2021 14:03:22 -0700 Subject: [PATCH 6/8] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- .../plugins/collective/horovod_collective.py | 4 ++-- .../plugins/collective/torch_collective.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py index 571f23f002960..db3f29669a214 100644 --- a/pytorch_lightning/plugins/collective/horovod_collective.py +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -32,14 +32,14 @@ def __init__( self, on_gpu: Optional[bool] = False, local_rank: int = 0, - ): + ) -> None: self.on_gpu = on_gpu self.local_rank = local_rank def join(self) -> None: """Horovod function that indicates that the rank finished processing data. - All ranks that did not call join() continue to process allreduce operations. This function blocks Python thread + All ranks that did not call join() continue to process allreduce operations. This function blocks the Python thread until all ranks join. """ if self.on_gpu: diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py index 546a107992a00..f94389afcce1b 100644 --- a/pytorch_lightning/plugins/collective/torch_collective.py +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -36,12 +36,12 @@ def __init__( device: Optional[Union[str, torch.device]] = torch.device("cpu"), device_id: Optional[int] = None, world_size: int = 1, - ): - """.. note:: - - DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False - DP run reduce in on node, local_reduce = True - DDP2 behaves like DP in one node, local_reduce = True + ) -> None: + """ + Note: + DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False + DP run reduce in on node, local_reduce = True + DDP2 behaves like DP in one node, local_reduce = True local_reduce set in Plugins.setup() functions """ @@ -84,6 +84,7 @@ def reduce( If local_reduce = True (dp and ddp2), reduces tensor from all local processes. If local_reduce = False (ddp, ddpspawning and extentions), reduces a tensor from several distributed processes + Args: tensor: the tensor to sync and reduce group: the process group to gather results from. Defaults to all processes (world) From 2379aea76b5245ef69a644c0e054dccf649d1356 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 12 Sep 2021 21:04:25 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/collective/horovod_collective.py | 4 ++-- pytorch_lightning/plugins/collective/torch_collective.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py index db3f29669a214..3f1398f3be8cf 100644 --- a/pytorch_lightning/plugins/collective/horovod_collective.py +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -39,8 +39,8 @@ def __init__( def join(self) -> None: """Horovod function that indicates that the rank finished processing data. - All ranks that did not call join() continue to process allreduce operations. This function blocks the Python thread - until all ranks join. + All ranks that did not call join() continue to process allreduce operations. This function blocks the Python + thread until all ranks join. """ if self.on_gpu: hvd.join(self.local_rank) diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py index f94389afcce1b..98f8f519cdec1 100644 --- a/pytorch_lightning/plugins/collective/torch_collective.py +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -84,7 +84,7 @@ def reduce( If local_reduce = True (dp and ddp2), reduces tensor from all local processes. If local_reduce = False (ddp, ddpspawning and extentions), reduces a tensor from several distributed processes - + Args: tensor: the tensor to sync and reduce group: the process group to gather results from. Defaults to all processes (world) From c7fce7c6be673924ccb040e74f9faebbab44e99e Mon Sep 17 00:00:00 2001 From: Siyu Wang Date: Sun, 12 Sep 2021 19:01:27 -0700 Subject: [PATCH 8/8] 3/n Consolidate collective functions - Integrate with TTPs --- CHANGELOG.md | 3 + pytorch_lightning/accelerators/accelerator.py | 6 +- pytorch_lightning/callbacks/early_stopping.py | 2 +- .../callbacks/model_checkpoint.py | 10 ++- .../callbacks/xla_stats_monitor.py | 8 +- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/loops/base.py | 4 +- .../plugins/collective/torch_collective.py | 11 ++- .../plugins/training_type/ddp.py | 53 +++--------- .../plugins/training_type/ddp2.py | 24 +----- .../plugins/training_type/ddp_spawn.py | 53 +++--------- .../plugins/training_type/deepspeed.py | 8 +- pytorch_lightning/plugins/training_type/dp.py | 38 +++------ .../plugins/training_type/fully_sharded.py | 6 +- .../plugins/training_type/horovod.py | 80 ++++--------------- .../plugins/training_type/ipu.py | 16 +--- .../plugins/training_type/parallel.py | 20 ++--- .../plugins/training_type/single_device.py | 31 ++----- .../plugins/training_type/tpu_spawn.py | 69 ++++------------ .../training_type/training_type_plugin.py | 30 +------ .../connectors/checkpoint_connector.py | 2 +- .../test_checkpoint_callback_frequency.py | 2 +- tests/core/test_metric_result_integration.py | 6 +- .../data/horovod/train_default_model.py | 4 +- tests/models/test_tpu.py | 6 +- 25 files changed, 133 insertions(+), 361 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 453822e27d4b3..c3a0f68aa8b3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -180,6 +180,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360)) +- Integrate `collective` class with `TrainingTypePlugin` ([#9472](https://github.com/PyTorchLightning/pytorch-lightning/pull/9472)) + + ### Deprecated - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()` diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 0a6060ee12304..cfbe84a183218 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -338,7 +338,7 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: return self.training_type_plugin.lightning_module_state_dict() def barrier(self, name: Optional[str] = None) -> None: - self.training_type_plugin.barrier(name=name) + self.training_type_plugin.collective.barrier(name=name) def broadcast(self, obj: object, src: int = 0) -> object: """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if @@ -348,7 +348,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: obj: Object to broadcast to all process, usually a tensor or collection of tensors. src: The source rank of which the object will be broadcast from """ - return self.training_type_plugin.broadcast(obj, src) + return self.training_type_plugin.collective.broadcast(obj, src) def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: """Function to gather a tensor from several distributed processes. @@ -361,7 +361,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo Return: A tensor of shape (world_size, batch, ...) """ - return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) + return self.training_type_plugin.collective.all_gather(tensor, group=group, sync_grads=sync_grads) def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: """Wraps the dataloader if necessary. diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 04c24059da7e2..fb466ba6dc6c7 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -206,7 +206,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: should_stop, reason = self._evaluate_stopping_criteria(current) # stop every ddp process if any world process decides to stop - should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop) + should_stop = trainer.training_type_plugin.collective.reduce_boolean_decision(should_stop) trainer.should_stop = trainer.should_stop or should_stop if should_stop: self.stopped_epoch = trainer.current_epoch diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 42cd078d2179d..f43c816887267 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -294,7 +294,7 @@ def on_train_batch_end( skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds() # in case we have time differences across ranks # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs - skip_time = trainer.training_type_plugin.broadcast(skip_time) + skip_time = trainer.training_type_plugin.collective.broadcast(skip_time) if skip_batch and skip_time: return @@ -509,7 +509,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Ten should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path]) # If using multiple devices, make sure all processes are unanimous on the decision. - should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save) + should_update_best_and_save = trainer.training_type_plugin.collective.reduce_boolean_decision( + should_update_best_and_save + ) return should_update_best_and_save @@ -612,7 +614,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: else: ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") - ckpt_path = trainer.training_type_plugin.broadcast(ckpt_path) + ckpt_path = trainer.training_type_plugin.collective.broadcast(ckpt_path) self.dirpath = ckpt_path @@ -748,4 +750,4 @@ def file_exists(self, filepath: Union[str, Path], trainer: "pl.Trainer") -> bool """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state to diverge between ranks.""" exists = self._fs.exists(filepath) - return trainer.training_type_plugin.broadcast(exists) + return trainer.training_type_plugin.collective.broadcast(exists) diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py index 07e3008aa6cd1..b9383c85857ea 100644 --- a/pytorch_lightning/callbacks/xla_stats_monitor.py +++ b/pytorch_lightning/callbacks/xla_stats_monitor.py @@ -67,7 +67,7 @@ def on_train_start(self, trainer, pl_module) -> None: ) memory_info = xm.get_memory_info(pl_module.device) - total_memory = trainer.training_type_plugin.reduce(memory_info["kb_total"]) * 0.001 + total_memory = trainer.training_type_plugin.collective.reduce(memory_info["kb_total"]) * 0.001 rank_zero_info(f"Average Total memory: {total_memory:.2f} MB") def on_train_epoch_start(self, trainer, pl_module) -> None: @@ -81,9 +81,9 @@ def on_train_epoch_end(self, trainer, pl_module) -> None: free_memory = memory_info["kb_free"] peak_memory = memory_info["kb_total"] - free_memory - free_memory = trainer.training_type_plugin.reduce(free_memory) * 0.001 - peak_memory = trainer.training_type_plugin.reduce(peak_memory) * 0.001 - epoch_time = trainer.training_type_plugin.reduce(epoch_time) + free_memory = trainer.training_type_plugin.collective.reduce(free_memory) * 0.001 + peak_memory = trainer.training_type_plugin.collective.reduce(peak_memory) * 0.001 + epoch_time = trainer.training_type_plugin.collective.reduce(epoch_time) logs["avg. free memory (MB)"] = free_memory logs["avg. peak memory (MB)"] = peak_memory diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6eec45465b247..d6cf11edd636b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -466,7 +466,7 @@ def log( dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, sync_dist=sync_dist and distributed_available(), - sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, + sync_dist_fn=self.trainer.training_type_plugin.collective.reduce or sync_ddp, sync_dist_group=sync_dist_group, metric_attribute=metric_attribute, rank_zero_only=rank_zero_only, diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 5573b04952ddd..6119c02e51fd1 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -238,7 +238,9 @@ def _load_from_state_dict( # On reload, we need to re-attach the `Metric`s back to the `ResultCollection`. # The references are provided through the `metric_attributes` dictionary. v.load_state_dict( - state_dict[prefix + k], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce + state_dict[prefix + k], + metrics=metric_attributes, + sync_fn=self.trainer.training_type_plugin.collective.reduce, ) if not self.trainer.is_global_zero: diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py index 98f8f519cdec1..2679a7b70f7ff 100644 --- a/pytorch_lightning/plugins/collective/torch_collective.py +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -107,7 +107,10 @@ def mean(t: torch.Tensor) -> torch.Tensor: return tensor def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.device) - decision = self.reduce(decision, reduce_op=ReduceOp.SUM) - decision = bool(decision == self.world_size) - return decision + if self.local_reduce: + return decision + else: + decision1 = torch.tensor(int(decision), device=self.device) + decision2 = self.reduce(decision1, reduce_op=ReduceOp.SUM) + decision = bool(decision2 == self.world_size) + return decision diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 647ff764f7c89..2ae448ece72cf 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -31,9 +31,10 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -48,13 +49,7 @@ rank_zero_deprecation, rank_zero_warn, ) -from pytorch_lightning.utilities.distributed import ( - distributed_available, - init_ddp_connection, - rank_zero_only, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -91,6 +86,7 @@ def __init__( num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, @@ -102,6 +98,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + collective=collective or TorchCollective(), ) self.interactive_ddp_procs = [] if num_nodes is not None: @@ -116,7 +113,6 @@ def __init__( " Notice that it will be overriden by the trainer setting." ) self._sync_batchnorm = sync_batchnorm or False - self.dist = LightningDistributed() self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0 self._ddp_kwargs = kwargs self._task_idx = None @@ -267,8 +263,10 @@ def setup_distributed(self): init_ddp_connection(self.cluster_environment, self.torch_distributed_backend) # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device + self.collective.rank = self.global_rank + self.collective.device = self.root_device + self.collective.device_id = self.determine_ddp_device_ids() + self.collective.world_size = self.world_size def _check_can_spawn_children(self): if self.local_rank != 0: @@ -389,17 +387,6 @@ def pre_dispatch(self): def post_dispatch(self, trainer: "pl.Trainer") -> None: self.cluster_environment.teardown() - def barrier(self, *args, **kwargs) -> None: - if not distributed_available(): - return - if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": - torch.distributed.barrier(device_ids=self.determine_ddp_device_ids()) - else: - torch.distributed.barrier() - - def broadcast(self, obj: object, src: int = 0) -> object: - return self.dist.broadcast(obj) - def pre_backward(self, closure_loss: torch.Tensor) -> None: """Run before precision plugin executes backward.""" if not self.lightning_module.automatic_optimization: @@ -408,22 +395,6 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def model_to_device(self): self.model.to(self.root_device) - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor: - """Reduces a tensor from several distributed processes to one aggregated tensor. - - Args: - tensor: the tensor to sync and reduce - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to 'mean'/'avg'. - Can also be a string 'sum' to calculate the sum during reduction. - - Return: - reduced value, except when the input was not a tensor the output remains is unchanged - """ - if isinstance(tensor, torch.Tensor): - tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) - return tensor - def training_step(self, *args, **kwargs) -> Optional[Any]: return self.model(*args, **kwargs) @@ -465,15 +436,15 @@ def _share_information_to_prevent_deadlock(self): sync_dirs = [] global_node_rank_zero = 0 for _ in range(self.num_nodes): - sync_dirs.append(self.broadcast(self._sync_dir, global_node_rank_zero)) + sync_dirs.append(self.collective.broadcast(self._sync_dir, global_node_rank_zero)) global_node_rank_zero += self.world_size // self.num_nodes self._sync_dir = sync_dirs[self.node_rank] def _share_pids(self): """Make all DDP processes aware of all processes pids.""" - self.barrier() - pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device)) + self.collective.barrier() + pids = self.collective.all_gather(torch.tensor(os.getpid(), device=self.root_device)) pids = pids.cpu().numpy().tolist() self._pids = pids if isinstance(pids, list) else [pids] diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index ae3954093880c..0015760933717 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -11,11 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch - from pytorch_lightning.plugins.training_type.ddp import DDPPlugin -from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.types import _METRIC_COLLECTION class DDP2Plugin(DDPPlugin): @@ -33,25 +29,7 @@ def setup(self) -> None: # set the task idx self.task_idx = self.cluster_environment.local_rank() # the difference to DDP is that we don't call children processes here - - def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: - """Reduces a collection of tensors from all processes. It can be applied to just a single tensor. In DDP2, - the reduction here is only across local devices within the node. - - Args: - collection: The collection of tensors to sync and reduce. - *args: ignored for DDP2 - **kwargs: ignored for DDP2 - - Return: - Reduced tensor values or the same value if it was not or did not contain a tensor. - """ - - def mean(t: torch.Tensor) -> torch.Tensor: - original_dtype = t.dtype - return t.float().mean().to(original_dtype) - - return apply_to_collection(collection, torch.Tensor, mean) + self.collective.local_reduce = True @property def root_device(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5f493001341d6..706f2420b1ea5 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,7 @@ import os import re from multiprocessing.queues import SimpleQueue -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import numpy as np import torch @@ -24,9 +24,10 @@ from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -40,13 +41,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.distributed import ( - distributed_available, - init_ddp_connection, - rank_zero_only, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -69,6 +64,7 @@ def __init__( num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, @@ -79,6 +75,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + collective=collective or TorchCollective(), ) if num_nodes is not None: rank_zero_deprecation( @@ -93,7 +90,6 @@ def __init__( ) self._sync_batchnorm = sync_batchnorm or False self._ddp_kwargs = kwargs - self.dist = LightningDistributed() self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 self.mp_queue = None self._ddp_comm_state = ddp_comm_state @@ -194,8 +190,10 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ # self.trainer.call_setup_hook(self.model) # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device + self.collective.rank = self.global_rank + self.collective.device = self.root_device + self.collective.device_id = self.determine_ddp_device_ids() + self.collective.world_size = self.world_size # move the model to the correct device self.model_to_device() @@ -208,7 +206,7 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ if trainer_fn == TrainerFn.FITTING: self.configure_ddp() - self.barrier() + self.collective.barrier() results = trainer.run_stage() @@ -313,19 +311,6 @@ def __recover_child_process_weights(self, best_path, last_path): ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) self.lightning_module.load_state_dict(ckpt) - def barrier(self, *args, **kwargs) -> None: - if not distributed_available(): - return - if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": - torch.distributed.barrier(device_ids=self.determine_ddp_device_ids()) - else: - torch.distributed.barrier() - - def broadcast(self, obj: object, src: int = 0) -> object: - if not distributed_available(): - return obj - return self.dist.broadcast(obj) - def model_to_device(self): if self.root_device.type == "cuda": # set the device on the spawned subprocesses @@ -337,22 +322,6 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: if not self.lightning_module.automatic_optimization: prepare_for_backward(self.model, closure_loss) - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor: - """Reduces a tensor from several distributed processes to one aggregated tensor. - - Args: - tensor: the tensor to sync and reduce - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to 'mean'/'avg'. - Can also be a string 'sum' to calculate the sum during reduction. - - Return: - reduced value, except when the input was not a tensor the output remains is unchanged - """ - if isinstance(tensor, torch.Tensor): - tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) - return tensor - def training_step(self, *args, **kwargs) -> Optional[Any]: return self.model(*args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index ca10b47bd9fd2..c63f4e5135dc5 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -26,6 +26,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp import DDPPlugin @@ -114,6 +116,7 @@ def __init__( num_nodes: Optional[int] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, + collective: Optional[Collective] = None, loss_scale: float = 0, initial_scale_power: int = 16, loss_scale_window: int = 1000, @@ -263,6 +266,7 @@ def __init__( parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment, + collective=collective or TorchCollective(), ) self.config = self._load_config(config) @@ -362,7 +366,7 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool: def pre_dispatch(self): self.init_deepspeed() - self.barrier() + self.collective.barrier() def init_deepspeed(self): self._handle_gradient_accumulation_steps() @@ -689,7 +693,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[st if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing - checkpoint_path = self.broadcast(checkpoint_path) + checkpoint_path = self.collective.broadcast(checkpoint_path) return super().load_checkpoint(checkpoint_path) # Rely on deepspeed to load the checkpoint and necessary information diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index fe970bb5a3bbc..3d562dfda6c3e 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -17,11 +17,11 @@ from torch.nn import DataParallel from pytorch_lightning.overrides.data_parallel import LightningParallelModule +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import _METRIC_COLLECTION class DataParallelPlugin(ParallelPlugin): @@ -32,8 +32,15 @@ def __init__( self, parallel_devices: Optional[List[torch.device]], checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, ): super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + collective=collective or TorchCollective(local_reduce=True), + ) @property def global_rank(self) -> int: @@ -56,24 +63,6 @@ def setup(self) -> None: self.model_to_device() self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices) - def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: - """Reduces a collection of tensors from all processes. It can be applied to just a single tensor. - - Args: - collection: The collection of tensors to sync and reduce. - *args: ignored for DP - **kwargs: ignored for DP - - Return: - Reduced tensor values or the same value if it was not or did not contain a tensor. - """ - - def mean(t: torch.Tensor) -> torch.Tensor: - original_dtype = t.dtype - return t.float().mean().to(original_dtype) - - return apply_to_collection(collection, torch.Tensor, mean) - @property def root_device(self): return self.parallel_devices[0] @@ -81,15 +70,6 @@ def root_device(self): def model_to_device(self) -> None: self._model.to(self.root_device) - def barrier(self, *args, **kwargs): - pass - - def broadcast(self, obj: object, src: int = 0) -> object: - return obj - - def reduce_boolean_decision(self, decision: bool) -> bool: - return decision - def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 72338e2923c07..7554113093fc8 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -16,6 +16,8 @@ import torch +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp import DDPPlugin @@ -42,6 +44,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, ): """Plugin for Fully Sharded Data Parallel provided by FairScale. @@ -93,6 +96,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + collective=collective or TorchCollective(), ) self.cpu_offload = cpu_offload self.move_grads_to_cpu = move_grads_to_cpu @@ -167,7 +171,7 @@ def pre_dispatch(self) -> None: if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) self.configure_ddp() - self.barrier() + self.collective.barrier() def model_to_device(self) -> None: # ensure we update the device type in the lightning module diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index c7a6fefea6542..efb37ceeed097 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import ExitStack -from typing import Any, List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch import torch.nn as nn @@ -20,12 +20,12 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import distributed_available -from pytorch_lightning.utilities.distributed import group as dist_group -from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import rank_zero_only if _HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -38,8 +38,14 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + collective=collective or HorovodCollective(), + ) rank_zero_only.rank = self.global_rank @property @@ -65,6 +71,8 @@ def distributed_sampler_kwargs(self): def setup(self) -> None: self.model_to_device() + self.collective.on_gpu = self.on_gpu + self.collective.local_rank = self.local_rank def pre_dispatch(self): @@ -108,14 +116,14 @@ def start_training(self, trainer): self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user - self.join() + self.collective.join() def start_evaluating(self, trainer): with ExitStack(): self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user - self.join() + self.collective.join() def start_predicting(self, trainer): with ExitStack(): @@ -123,15 +131,7 @@ def start_predicting(self, trainer): self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user - self.join() - - def barrier(self, *args, **kwargs): - if distributed_available(): - self.join() - - def broadcast(self, obj: object, src: int = 0) -> object: - obj = hvd.broadcast_object(obj, src) - return obj + self.collective.join() def model_to_device(self): if self.on_gpu: @@ -139,54 +139,6 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def join(self): - if self.on_gpu: - hvd.join(self.local_rank) - else: - hvd.join() - - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): - """Reduces a tensor from several distributed processes to one aggregated tensor. - - Args: - tensor: the tensor to sync and reduce - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to 'mean'/'avg'. - Can also be a string 'sum' to calculate the sum during reduction. - - Return: - reduced value, except when the input was not a tensor the output remains is unchanged - """ - if group is not None: - raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. Unset `group`.") - - if reduce_op in (None, "avg", "mean"): - reduce_op = hvd.Average - elif reduce_op in ("sum", ReduceOp.SUM): - reduce_op = hvd.Sum - else: - raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") - - # sync all processes before reduction - self.join() - return hvd.allreduce(tensor, op=reduce_op) - - def all_gather( - self, result: Union[torch.Tensor], group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False - ) -> torch.Tensor: - if group is not None and group != dist_group.WORLD: - raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.") - - if len(result.shape) == 0: - # Convert scalars to single dimension tensors - result = result.reshape(1) - - # sync and gather all - self.join() - gathered = hvd.allgather(result) - gathered_result = list(gathered.split(1, dim=0)) - return gathered_result - def post_backward(self, closure_loss: torch.Tensor) -> None: # synchronize all horovod optimizers. for optimizer in self.lightning_module.trainer.optimizers: diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 16bc5e4e9be4b..315ab2b3f7afd 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -21,6 +21,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -65,6 +67,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, training_opts: Optional["poptorch.Options"] = None, inference_opts: Optional["poptorch.Options"] = None, ) -> None: @@ -85,6 +88,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + collective=collective or SingleDeviceCollective(), ) if not _POPTORCH_AVAILABLE or not poptorch.ipuHardwareIsAvailable(): raise MisconfigurationException( @@ -318,15 +322,3 @@ def model_to_device(self) -> None: @property def is_global_zero(self) -> bool: return True - - def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: - return tensor - - def barrier(self, name: Optional[str] = None) -> None: - pass - - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - return tensor - - def broadcast(self, obj: object, src: int = 0) -> object: - return obj diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index e6406ea444947..b9f4801675040 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -14,18 +14,18 @@ import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, List, Optional +from typing import List, Optional import torch from torch.nn.parallel import DistributedDataParallel import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module +from pytorch_lightning.plugins.collective.collective_plugin import Collective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE -from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp class ParallelPlugin(TrainingTypePlugin, ABC): @@ -36,8 +36,12 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, ): - super().__init__(checkpoint_io) + super().__init__( + checkpoint_io=checkpoint_io, + collective=collective, + ) self.parallel_devices = parallel_devices self.cluster_environment = cluster_environment @@ -86,16 +90,6 @@ def distributed_sampler_kwargs(self): def reconciliate_processes(self, trace: str): """Function to re-conciliate processes on failure.""" - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes.""" - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.lightning_module.device) - decision = self.reduce(decision, reduce_op=ReduceOp.SUM) - decision = bool(decision == self.world_size) - return decision - @property def torch_distributed_backend(self): torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND") diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 1737bf3b41ca8..d3d70cb613455 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Optional import torch +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -27,8 +29,9 @@ def __init__( self, device: torch.device, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, ): - super().__init__(checkpoint_io) + super().__init__(checkpoint_io=checkpoint_io, collective=collective or SingleDeviceCollective()) self.device: torch.device = device self.global_rank = 0 self.local_rank = 0 @@ -42,24 +45,6 @@ def on_tpu(self) -> bool: def on_gpu(self) -> bool: return self.root_device.type == "cuda" and torch.cuda.is_available() - def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]: - """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only - operates with a single device, the reduction is simply the identity. - - Args: - tensor: the tensor to sync and reduce - *args: ignored - **kwargs: ignored - - Return: - the unmodified input as reduction is not needed for single process operation - """ - return tensor - - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes.""" - return tensor - @property def root_device(self) -> torch.device: return self.device @@ -74,12 +59,6 @@ def setup(self) -> None: def is_global_zero(self) -> bool: return True - def barrier(self, *args, **kwargs) -> None: - pass - - def broadcast(self, obj: object, src: int = 0) -> object: - return obj - def teardown(self) -> None: if self.on_gpu: # GPU teardown diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 9140a995c8b56..3bc664dfc658c 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io import os import re import time @@ -25,6 +24,8 @@ import pytorch_lightning as pl from pytorch_lightning.core.decorators import parameter_validation from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader @@ -32,7 +33,7 @@ from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.data import has_len -from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -53,8 +54,14 @@ class TPUSpawnPlugin(DDPSpawnPlugin): """Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method.""" - def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None: - super().__init__(parallel_devices=parallel_devices) + def __init__( + self, + parallel_devices: Optional[List[int]] = None, + collective: Optional[Collective] = None, + debug: bool = False, + **_: Any + ) -> None: + super().__init__(parallel_devices=parallel_devices, collective=collective or TPUCollective()) self.debug = debug self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 @@ -157,6 +164,10 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: trainer.progress_bar_callback.disable() self.model_to_device() + self.collective.device = self.lightning_module.device + self.collective.root_device = self.root_device + self.collective.world_size = self.world_size + trainer.accelerator.setup_optimizers(trainer) trainer.precision_plugin.connect(self._model, None, None) @@ -210,42 +221,6 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul def save(self, state_dict: Dict, path: str) -> None: xm.save(state_dict, path) - def broadcast(self, obj: object, src: int = 0) -> object: - if not self.is_distributed: - return obj - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) - data = xm.all_gather(data_tensor) - buffer = io.BytesIO(data.cpu().byte().numpy()) - obj = torch.load(buffer) - return obj - - def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.lightning_module.device) - decision = self.reduce(decision, reduce_op="sum") - decision = bool(decision == self.world_size) - return decision - - def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): - if not isinstance(output, torch.Tensor): - output = torch.tensor(output, device=self.lightning_module.device) - - _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM - _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") - if _invalid_reduce_op or _invalid_reduce_op_str: - raise MisconfigurationException( - "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation." - ) - - output = xm.mesh_reduce("reduce", output, sum) - - if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): - output = output / self.world_size - - return output - def _close_logger(self, trainer) -> None: if trainer.logger is not None: trainer.logger.finalize("success") @@ -315,20 +290,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container) self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath) - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes - Args: - tensor: tensor of shape (batch, ...) - group: not available with TPUs - sync_grads: not available with TPUs - Return: - A tensor of shape (world_size, batch, ...) - """ - if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: - tensor = tensor.unsqueeze(0) - return xm.all_gather(tensor) - def teardown(self) -> None: # TPU teardown os.environ.pop("PT_XLA_DEBUG", None) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index e5c015c34ab59..d40e99a201111 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -25,6 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO +from pytorch_lightning.plugins.collective.collective_plugin import Collective from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT @@ -35,11 +36,12 @@ class TrainingTypePlugin(ABC): """Base class for all training type plugins that change the behaviour of the training, validation and test- loop.""" - def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None: + def __init__(self, checkpoint_io: Optional[CheckpointIO] = None, collective: Optional[Collective] = None) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io + self.collective = collective self._call_configure_sharded_model_hook = True @property @@ -91,32 +93,6 @@ def model_to_device(self) -> None: def is_global_zero(self) -> bool: """Whether the current process is the rank zero process not only on the local node, but for all nodes.""" - @abstractmethod - def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: - """Reduces the given tensor (e.g. across GPUs/processes). - - Args: - tensor: the tensor to sync and reduce - *args: plugin-specific positional arguments - **kwargs: plugin-specific keyword arguments - """ - - @abstractmethod - def barrier(self, name: Optional[str] = None) -> None: - """Forces all possibly joined processes to wait for each other.""" - - @abstractmethod - def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - """Broadcasts an object to all processes.""" - - @abstractmethod - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes.""" - - def reduce_boolean_decision(self, decision: bool) -> bool: - """Reduce the early stopping decision across all processes.""" - return decision - def pre_backward(self, closure_loss: torch.Tensor) -> None: """Run before precision plugin executes backward.""" diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index dee8aea79a304..cdf911fd0feb0 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -76,7 +76,7 @@ def resume_end(self) -> None: torch.cuda.empty_cache() # wait for all to catch up - self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end") + self.trainer.training_type_plugin.collective.barrier("CheckpointConnector.resume_end") def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None: """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 12ec14712fc94..530efa5633188 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -111,7 +111,7 @@ def training_epoch_end(self, outputs) -> None: self.log("my_loss_2", (1 + local_rank), on_epoch=True, rank_zero_only=True) data = str(self.global_rank) obj = [[data], (data,), set(data)] - out = self.trainer.training_type_plugin.broadcast(obj) + out = self.trainer.training_type_plugin.collective.broadcast(obj) assert obj == [[str(self.global_rank)], (str(self.global_rank),), set(str(self.global_rank))] assert out == [["0"], ("0",), set("0")] diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index a7572f9a77394..2f9b577d6a0d4 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -316,7 +316,7 @@ def on_save_checkpoint(self, checkpoint) -> None: assert state_dict["items"]["validation_step.v"]["value"].device.type == device # sync fn should be kept - assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce + assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.collective.reduce # sync fn dropped from the state dict assert "fn" not in state_dict["items"]["validation_step.v"]["meta"]["_sync"] @@ -326,7 +326,7 @@ def on_save_checkpoint(self, checkpoint) -> None: assert results["validation_step.v"].value.device.type == device # sync fn was preserved in the original result - assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce + assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.collective.reduce # default sync fn new_results = ResultCollection(False, device) @@ -453,7 +453,7 @@ def on_epoch_end(self) -> None: assert not model.has_validated_sum tmpdir = ( - trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0) + trainer.training_type_plugin.collective.broadcast(trainer_kwargs["default_root_dir"], 0) if num_processes >= 2 else trainer_kwargs["default_root_dir"] ) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 71acb9a168081..b0261cff5ccb4 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -58,7 +58,9 @@ def on_train_start(self) -> None: assert self.device == expected_device def training_epoch_end(self, outputs) -> None: - res = self.trainer.training_type_plugin.reduce(torch.tensor(1.0, device=self.device), reduce_op="sum") + res = self.trainer.training_type_plugin.collective.reduce( + torch.tensor(1.0, device=self.device), reduce_op="sum" + ) assert res.sum() == self.trainer.training_type_plugin.world_size model = TestModel() diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 950c3577b89b9..918336cc0be08 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -292,7 +292,7 @@ def test_broadcast(rank): assert isinstance(trainer.accelerator, TPUAccelerator) assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin) obj = ("ver_0.5", "logger_name", rank) - result = trainer.training_type_plugin.broadcast(obj) + result = trainer.training_type_plugin.collective.broadcast(obj) assert result == ("ver_0.5", "logger_name", 0) xmp.spawn(test_broadcast, nprocs=8, start_method="fork") @@ -356,9 +356,9 @@ def test_reduce(rank): for reduce_op in reduce_ops: if reduce_op == "undefined" or reduce_op == ReduceOp.MAX: with pytest.raises(MisconfigurationException, match="TPUSpawn TrainingTypePlugin only support"): - result = trainer.training_type_plugin.reduce(1, reduce_op) + result = trainer.training_type_plugin.collective.reduce(1, reduce_op) else: - result = trainer.training_type_plugin.reduce(1, reduce_op) + result = trainer.training_type_plugin.collective.reduce(1, reduce_op) if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"): assert result.item() == 1 else: