diff --git a/CHANGELOG.md b/CHANGELOG.md index a8b068e5eccdc..c3a0f68aa8b3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -120,6 +120,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344)) +- Add collective base class and subclasses ([#9414](https://github.com/PyTorchLightning/pytorch-lightning/pull/9414)) + + ### Changed - `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)). @@ -177,6 +180,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360)) +- Integrate `collective` class with `TrainingTypePlugin` ([#9472](https://github.com/PyTorchLightning/pytorch-lightning/pull/9472)) + + ### Deprecated - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()` diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 0a6060ee12304..cfbe84a183218 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -338,7 +338,7 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: return self.training_type_plugin.lightning_module_state_dict() def barrier(self, name: Optional[str] = None) -> None: - self.training_type_plugin.barrier(name=name) + self.training_type_plugin.collective.barrier(name=name) def broadcast(self, obj: object, src: int = 0) -> object: """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if @@ -348,7 +348,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: obj: Object to broadcast to all process, usually a tensor or collection of tensors. src: The source rank of which the object will be broadcast from """ - return self.training_type_plugin.broadcast(obj, src) + return self.training_type_plugin.collective.broadcast(obj, src) def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: """Function to gather a tensor from several distributed processes. @@ -361,7 +361,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo Return: A tensor of shape (world_size, batch, ...) """ - return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) + return self.training_type_plugin.collective.all_gather(tensor, group=group, sync_grads=sync_grads) def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: """Wraps the dataloader if necessary. 
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 04c24059da7e2..fb466ba6dc6c7 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -206,7 +206,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: should_stop, reason = self._evaluate_stopping_criteria(current) # stop every ddp process if any world process decides to stop - should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop) + should_stop = trainer.training_type_plugin.collective.reduce_boolean_decision(should_stop) trainer.should_stop = trainer.should_stop or should_stop if should_stop: self.stopped_epoch = trainer.current_epoch diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 42cd078d2179d..f43c816887267 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -294,7 +294,7 @@ def on_train_batch_end( skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds() # in case we have time differences across ranks # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs - skip_time = trainer.training_type_plugin.broadcast(skip_time) + skip_time = trainer.training_type_plugin.collective.broadcast(skip_time) if skip_batch and skip_time: return @@ -509,7 +509,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Ten should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path]) # If using multiple devices, make sure all processes are unanimous on the decision. - should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save) + should_update_best_and_save = trainer.training_type_plugin.collective.reduce_boolean_decision( + should_update_best_and_save + ) return should_update_best_and_save @@ -612,7 +614,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: else: ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") - ckpt_path = trainer.training_type_plugin.broadcast(ckpt_path) + ckpt_path = trainer.training_type_plugin.collective.broadcast(ckpt_path) self.dirpath = ckpt_path @@ -748,4 +750,4 @@ def file_exists(self, filepath: Union[str, Path], trainer: "pl.Trainer") -> bool """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state to diverge between ranks.""" exists = self._fs.exists(filepath) - return trainer.training_type_plugin.broadcast(exists) + return trainer.training_type_plugin.collective.broadcast(exists) diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py index 07e3008aa6cd1..b9383c85857ea 100644 --- a/pytorch_lightning/callbacks/xla_stats_monitor.py +++ b/pytorch_lightning/callbacks/xla_stats_monitor.py @@ -67,7 +67,7 @@ def on_train_start(self, trainer, pl_module) -> None: ) memory_info = xm.get_memory_info(pl_module.device) - total_memory = trainer.training_type_plugin.reduce(memory_info["kb_total"]) * 0.001 + total_memory = trainer.training_type_plugin.collective.reduce(memory_info["kb_total"]) * 0.001 rank_zero_info(f"Average Total memory: {total_memory:.2f} MB") def on_train_epoch_start(self, trainer, pl_module) -> None: @@ -81,9 +81,9 @@ def on_train_epoch_end(self, trainer, pl_module) -> None: free_memory = 
memory_info["kb_free"] peak_memory = memory_info["kb_total"] - free_memory - free_memory = trainer.training_type_plugin.reduce(free_memory) * 0.001 - peak_memory = trainer.training_type_plugin.reduce(peak_memory) * 0.001 - epoch_time = trainer.training_type_plugin.reduce(epoch_time) + free_memory = trainer.training_type_plugin.collective.reduce(free_memory) * 0.001 + peak_memory = trainer.training_type_plugin.collective.reduce(peak_memory) * 0.001 + epoch_time = trainer.training_type_plugin.collective.reduce(epoch_time) logs["avg. free memory (MB)"] = free_memory logs["avg. peak memory (MB)"] = peak_memory diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6eec45465b247..d6cf11edd636b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -466,7 +466,7 @@ def log( dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, sync_dist=sync_dist and distributed_available(), - sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, + sync_dist_fn=self.trainer.training_type_plugin.collective.reduce or sync_ddp, sync_dist_group=sync_dist_group, metric_attribute=metric_attribute, rank_zero_only=rank_zero_only, diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 5573b04952ddd..6119c02e51fd1 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -238,7 +238,9 @@ def _load_from_state_dict( # On reload, we need to re-attach the `Metric`s back to the `ResultCollection`. # The references are provided through the `metric_attributes` dictionary. v.load_state_dict( - state_dict[prefix + k], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce + state_dict[prefix + k], + metrics=metric_attributes, + sync_fn=self.trainer.training_type_plugin.collective.reduce, ) if not self.trainer.is_global_zero: diff --git a/pytorch_lightning/plugins/collective/__init__.py b/pytorch_lightning/plugins/collective/__init__.py new file mode 100644 index 0000000000000..ab5ad5c2146ed --- /dev/null +++ b/pytorch_lightning/plugins/collective/__init__.py @@ -0,0 +1,18 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from pytorch_lightning.plugins.collective.collective_plugin import Collective # noqa: F401 +from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective # noqa: F401 +from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective # noqa: F401 +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective # noqa: F401 +from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective # noqa: F401 diff --git a/pytorch_lightning/plugins/collective/collective_plugin.py b/pytorch_lightning/plugins/collective/collective_plugin.py new file mode 100644 index 0000000000000..c3512452883d3 --- /dev/null +++ b/pytorch_lightning/plugins/collective/collective_plugin.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, List, Optional, Union + +import torch + +from pytorch_lightning.utilities.distributed import ReduceOp + + +class Collective(ABC): + """Base class for collective functions for training type plugins.""" + + @abstractmethod + def barrier(self, name: Optional[str] = None) -> None: + """Forces all possibly joined processes to wait for each other.""" + + @abstractmethod + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes.""" + + @abstractmethod + def all_gather( + self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[List[torch.Tensor], torch.Tensor]: + """Perform an all_gather on all processes.""" + + @abstractmethod + def reduce( + self, + tensor: Union[torch.Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, Any]: + """Reduces the given tensor (e.g. across GPUs/processes). + + Args: + tensor: the tensor to sync and reduce + group: the process group to reduce results across. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. Can also be 'sum'. + """ + + @abstractmethod + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" diff --git a/pytorch_lightning/plugins/collective/horovod_collective.py b/pytorch_lightning/plugins/collective/horovod_collective.py new file mode 100644 index 0000000000000..3f1398f3be8cf --- /dev/null +++ b/pytorch_lightning/plugins/collective/horovod_collective.py @@ -0,0 +1,107 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+from typing import Any, List, Optional, Union + +import torch + +from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE +from pytorch_lightning.utilities.distributed import distributed_available +from pytorch_lightning.utilities.distributed import group as dist_group +from pytorch_lightning.utilities.distributed import ReduceOp + +if _HOROVOD_AVAILABLE: + import horovod.torch as hvd + + +class HorovodCollective(Collective): + """Collective interface for Horovod training type plugins.""" + + def __init__( + self, + on_gpu: Optional[bool] = False, + local_rank: int = 0, + ) -> None: + self.on_gpu = on_gpu + self.local_rank = local_rank + + def join(self) -> None: + """Horovod function that indicates that the rank finished processing data. + + All ranks that did not call join() continue to process allreduce operations. This function blocks the Python + thread until all ranks join. + """ + if self.on_gpu: + hvd.join(self.local_rank) + else: + hvd.join() + + def barrier(self, *args: Any, **kwargs: Any) -> None: + if distributed_available(): + self.join() + + def broadcast(self, obj: object, src: int = 0) -> object: + obj = hvd.broadcast_object(obj, src) + return obj + + def all_gather( + self, result: torch.Tensor, group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False + ) -> List[torch.Tensor]: + if group is not None and group != dist_group.WORLD: + raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.") + + if len(result.shape) == 0: + # Convert scalars to single dimension tensors + result = result.reshape(1) + + # sync and gather all + self.join() + gathered = hvd.allgather(result) + gathered_result = list(gathered.split(1, dim=0)) + return gathered_result + + def reduce( + self, + tensor: Union[torch.Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, Any]: + """Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor, in which case the output remains unchanged + """ + if group is not None: + raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. Unset `group`.") + + if reduce_op in (None, "avg", "mean"): + reduce_op = hvd.Average + elif reduce_op in ("sum", ReduceOp.SUM): + reduce_op = hvd.Sum + else: + raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") + + # sync all processes before reduction + self.join() + return hvd.allreduce(tensor, op=reduce_op) + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision diff --git a/pytorch_lightning/plugins/collective/single_device_collective.py b/pytorch_lightning/plugins/collective/single_device_collective.py new file mode 100644 index 0000000000000..18010f244c9a3 --- /dev/null +++ b/pytorch_lightning/plugins/collective/single_device_collective.py @@ -0,0 +1,48 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union + +import torch + +from pytorch_lightning.plugins.collective import Collective + + +class SingleDeviceCollective(Collective): + """Collective interface for single device training type plugins.""" + + def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: + """Forces all possibly joined processes to wait for each other.""" + pass + + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes.""" + return obj + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform an all_gather on all processes.""" + return tensor + + def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + """Reduces the given tensor (e.g. across GPUs/processes). + + Args: + tensor: the tensor to sync and reduce + *args: plugin-specific positional arguments + **kwargs: plugin-specific keyword arguments + """ + return tensor + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision diff --git a/pytorch_lightning/plugins/collective/torch_collective.py b/pytorch_lightning/plugins/collective/torch_collective.py new file mode 100644 index 0000000000000..2679a7b70f7ff --- /dev/null +++ b/pytorch_lightning/plugins/collective/torch_collective.py @@ -0,0 +1,116 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+from typing import Any, List, Optional, Union + +import torch +import torch.distributed + +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list +from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, distributed_available +from pytorch_lightning.utilities.distributed import group as _group +from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available +from pytorch_lightning.utilities.types import _METRIC_COLLECTION + + +class TorchCollective(Collective): + """Collective interface for DDP, DDPSpawn, DP and DDP2.""" + + def __init__( + self, + local_reduce: bool = False, + rank: Optional[int] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device_id: Optional[List[int]] = None, + world_size: int = 1, + ) -> None: + """ + Note: + DDP and DDPSpawn sync across multiple nodes/devices, local_reduce = False + DP runs the reduce on a single node, local_reduce = True + DDP2 behaves like DP within one node, local_reduce = True + + local_reduce is set in the plugins' setup() functions + """ + self.local_reduce = local_reduce + self.rank = rank + self.device = device + self.device_id = device_id + self.world_size = world_size + + def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: + if not distributed_available(): + return + if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=self.device_id) + else: + torch.distributed.barrier() + + def broadcast(self, obj: Any, src: int = 0) -> Any: + if not distributed_available(): + return obj + else: + obj = [obj] + if self.rank != 0: + obj = [None] * len(obj) + broadcast_object_list(obj, 0, group=_group.WORLD) + return obj[0] + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform an all_gather on all processes.""" + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + + def reduce( + self, + tensor: Union[torch.Tensor, _METRIC_COLLECTION], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Union[torch.Tensor, _METRIC_COLLECTION]: + """Reduces the given tensor (e.g. across GPUs/processes). + + If local_reduce = True (dp and ddp2), reduces the tensor across all local processes. + + If local_reduce = False (ddp, ddp_spawn and extensions), reduces a tensor from several distributed processes. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction.
+ + Return: + reduced value, except when the input was not a tensor, in which case the output remains unchanged + """ + if self.local_reduce: + + def mean(t: torch.Tensor) -> torch.Tensor: + original_dtype = t.dtype + return t.float().mean().to(original_dtype) + + return apply_to_collection(tensor, torch.Tensor, mean) + + if isinstance(tensor, torch.Tensor): + tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + def reduce_boolean_decision(self, decision: bool) -> bool: + if self.local_reduce: + return decision + else: + decision1 = torch.tensor(int(decision), device=self.device) + decision2 = self.reduce(decision1, reduce_op=ReduceOp.SUM) + decision = bool(decision2 == self.world_size) + return decision diff --git a/pytorch_lightning/plugins/collective/tpu_collective.py b/pytorch_lightning/plugins/collective/tpu_collective.py new file mode 100644 index 0000000000000..6cacd2c688fb7 --- /dev/null +++ b/pytorch_lightning/plugins/collective/tpu_collective.py @@ -0,0 +1,99 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +import os +from typing import Any, Optional, Union + +import torch + +from pytorch_lightning.plugins.collective import Collective +from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.distributed import ReduceOp +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _TPU_AVAILABLE: + import torch_xla.core.xla_env_vars as xenv + import torch_xla.core.xla_model as xm + from torch_xla.core.xla_model import rendezvous + + +class TPUCollective(Collective): + """Collective interface for TPUSpawn training type plugins.""" + + def __init__( + self, + device: Union[str, torch.device] = torch.device("xla"), + root_device: torch.device = torch.device("xla"), + world_size: int = 1, + ): + self.device = device + self.root_device = root_device + self.world_size = world_size + + @property + def is_distributed(self) -> bool: + # HOST_WORLD_SIZE is None outside the xmp.spawn process + return os.getenv(xenv.HOST_WORLD_SIZE, None) is not None and self.world_size != 1 + + def barrier(self, name: Optional[str] = None) -> None: + if self.is_distributed: + rendezvous(name) + + def broadcast(self, obj: object, src: int = 0) -> object: + if not self.is_distributed: + return obj + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) + data = xm.all_gather(data_tensor) + buffer = io.BytesIO(data.cpu().byte().numpy()) + obj = torch.load(buffer) + return obj + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + Args: + tensor: tensor of shape (batch, ...) + group: not available with TPUs + sync_grads: not available with TPUs + Return: + A tensor of shape (world_size, batch, ...)
+ """ + if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + return xm.all_gather(tensor) + + def reduce(self, output: Any, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Any: + if not isinstance(output, torch.Tensor): + output = torch.tensor(output, device=self.device) + + _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM + _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") + if _invalid_reduce_op or _invalid_reduce_op_str: + raise MisconfigurationException( + "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation." + ) + + output = xm.mesh_reduce("reduce", output, sum) + + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + output = output / self.world_size + + return output + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes.""" + return decision diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 647ff764f7c89..2ae448ece72cf 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -31,9 +31,10 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -48,13 +49,7 @@ rank_zero_deprecation, rank_zero_warn, ) -from pytorch_lightning.utilities.distributed import ( - distributed_available, - init_ddp_connection, - rank_zero_only, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -91,6 +86,7 @@ def __init__( num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, @@ -102,6 +98,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + collective=collective or TorchCollective(), ) self.interactive_ddp_procs = [] if num_nodes is not None: @@ -116,7 +113,6 @@ def __init__( " Notice that it will be overriden by the trainer setting." 
) self._sync_batchnorm = sync_batchnorm or False - self.dist = LightningDistributed() self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0 self._ddp_kwargs = kwargs self._task_idx = None @@ -267,8 +263,10 @@ def setup_distributed(self): init_ddp_connection(self.cluster_environment, self.torch_distributed_backend) # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device + self.collective.rank = self.global_rank + self.collective.device = self.root_device + self.collective.device_id = self.determine_ddp_device_ids() + self.collective.world_size = self.world_size def _check_can_spawn_children(self): if self.local_rank != 0: @@ -389,17 +387,6 @@ def pre_dispatch(self): def post_dispatch(self, trainer: "pl.Trainer") -> None: self.cluster_environment.teardown() - def barrier(self, *args, **kwargs) -> None: - if not distributed_available(): - return - if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": - torch.distributed.barrier(device_ids=self.determine_ddp_device_ids()) - else: - torch.distributed.barrier() - - def broadcast(self, obj: object, src: int = 0) -> object: - return self.dist.broadcast(obj) - def pre_backward(self, closure_loss: torch.Tensor) -> None: """Run before precision plugin executes backward.""" if not self.lightning_module.automatic_optimization: @@ -408,22 +395,6 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def model_to_device(self): self.model.to(self.root_device) - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor: - """Reduces a tensor from several distributed processes to one aggregated tensor. - - Args: - tensor: the tensor to sync and reduce - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to 'mean'/'avg'. - Can also be a string 'sum' to calculate the sum during reduction. - - Return: - reduced value, except when the input was not a tensor the output remains is unchanged - """ - if isinstance(tensor, torch.Tensor): - tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) - return tensor - def training_step(self, *args, **kwargs) -> Optional[Any]: return self.model(*args, **kwargs) @@ -465,15 +436,15 @@ def _share_information_to_prevent_deadlock(self): sync_dirs = [] global_node_rank_zero = 0 for _ in range(self.num_nodes): - sync_dirs.append(self.broadcast(self._sync_dir, global_node_rank_zero)) + sync_dirs.append(self.collective.broadcast(self._sync_dir, global_node_rank_zero)) global_node_rank_zero += self.world_size // self.num_nodes self._sync_dir = sync_dirs[self.node_rank] def _share_pids(self): """Make all DDP processes aware of all processes pids.""" - self.barrier() - pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device)) + self.collective.barrier() + pids = self.collective.all_gather(torch.tensor(os.getpid(), device=self.root_device)) pids = pids.cpu().numpy().tolist() self._pids = pids if isinstance(pids, list) else [pids] diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index ae3954093880c..0015760933717 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -11,11 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -import torch - from pytorch_lightning.plugins.training_type.ddp import DDPPlugin -from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.types import _METRIC_COLLECTION class DDP2Plugin(DDPPlugin): @@ -33,25 +29,7 @@ def setup(self) -> None: # set the task idx self.task_idx = self.cluster_environment.local_rank() # the difference to DDP is that we don't call children processes here - - def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: - """Reduces a collection of tensors from all processes. It can be applied to just a single tensor. In DDP2, - the reduction here is only across local devices within the node. - - Args: - collection: The collection of tensors to sync and reduce. - *args: ignored for DDP2 - **kwargs: ignored for DDP2 - - Return: - Reduced tensor values or the same value if it was not or did not contain a tensor. - """ - - def mean(t: torch.Tensor) -> torch.Tensor: - original_dtype = t.dtype - return t.float().mean().to(original_dtype) - - return apply_to_collection(collection, torch.Tensor, mean) + self.collective.local_reduce = True @property def root_device(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5f493001341d6..706f2420b1ea5 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,7 @@ import os import re from multiprocessing.queues import SimpleQueue -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import numpy as np import torch @@ -24,9 +24,10 @@ from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -40,13 +41,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.distributed import ( - distributed_available, - init_ddp_connection, - rank_zero_only, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -69,6 +64,7 @@ def __init__( num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, @@ -79,6 +75,7 @@ def __init__( 
parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + collective=collective or TorchCollective(), ) if num_nodes is not None: rank_zero_deprecation( @@ -93,7 +90,6 @@ def __init__( ) self._sync_batchnorm = sync_batchnorm or False self._ddp_kwargs = kwargs - self.dist = LightningDistributed() self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 self.mp_queue = None self._ddp_comm_state = ddp_comm_state @@ -194,8 +190,10 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ # self.trainer.call_setup_hook(self.model) # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device + self.collective.rank = self.global_rank + self.collective.device = self.root_device + self.collective.device_id = self.determine_ddp_device_ids() + self.collective.world_size = self.world_size # move the model to the correct device self.model_to_device() @@ -208,7 +206,7 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ if trainer_fn == TrainerFn.FITTING: self.configure_ddp() - self.barrier() + self.collective.barrier() results = trainer.run_stage() @@ -313,19 +311,6 @@ def __recover_child_process_weights(self, best_path, last_path): ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) self.lightning_module.load_state_dict(ckpt) - def barrier(self, *args, **kwargs) -> None: - if not distributed_available(): - return - if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": - torch.distributed.barrier(device_ids=self.determine_ddp_device_ids()) - else: - torch.distributed.barrier() - - def broadcast(self, obj: object, src: int = 0) -> object: - if not distributed_available(): - return obj - return self.dist.broadcast(obj) - def model_to_device(self): if self.root_device.type == "cuda": # set the device on the spawned subprocesses @@ -337,22 +322,6 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: if not self.lightning_module.automatic_optimization: prepare_for_backward(self.model, closure_loss) - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor: - """Reduces a tensor from several distributed processes to one aggregated tensor. - - Args: - tensor: the tensor to sync and reduce - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to 'mean'/'avg'. - Can also be a string 'sum' to calculate the sum during reduction. 
- - Return: - reduced value, except when the input was not a tensor the output remains is unchanged - """ - if isinstance(tensor, torch.Tensor): - tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) - return tensor - def training_step(self, *args, **kwargs) -> Optional[Any]: return self.model(*args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index ca10b47bd9fd2..c63f4e5135dc5 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -26,6 +26,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp import DDPPlugin @@ -114,6 +116,7 @@ def __init__( num_nodes: Optional[int] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, + collective: Optional[Collective] = None, loss_scale: float = 0, initial_scale_power: int = 16, loss_scale_window: int = 1000, @@ -263,6 +266,7 @@ def __init__( parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment, + collective=collective or TorchCollective(), ) self.config = self._load_config(config) @@ -362,7 +366,7 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool: def pre_dispatch(self): self.init_deepspeed() - self.barrier() + self.collective.barrier() def init_deepspeed(self): self._handle_gradient_accumulation_steps() @@ -689,7 +693,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[st if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing - checkpoint_path = self.broadcast(checkpoint_path) + checkpoint_path = self.collective.broadcast(checkpoint_path) return super().load_checkpoint(checkpoint_path) # Rely on deepspeed to load the checkpoint and necessary information diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index fe970bb5a3bbc..3d562dfda6c3e 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -17,11 +17,11 @@ from torch.nn import DataParallel from pytorch_lightning.overrides.data_parallel import LightningParallelModule +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import _METRIC_COLLECTION class DataParallelPlugin(ParallelPlugin): @@ -32,8 +32,15 @@ def __init__( self, parallel_devices: Optional[List[torch.device]], checkpoint_io: Optional[CheckpointIO] = None, + 
collective: Optional[Collective] = None, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + collective=collective or TorchCollective(local_reduce=True), + ) @property def global_rank(self) -> int: @@ -56,24 +63,6 @@ def setup(self) -> None: self.model_to_device() self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices) - def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: - """Reduces a collection of tensors from all processes. It can be applied to just a single tensor. - - Args: - collection: The collection of tensors to sync and reduce. - *args: ignored for DP - **kwargs: ignored for DP - - Return: - Reduced tensor values or the same value if it was not or did not contain a tensor. - """ - - def mean(t: torch.Tensor) -> torch.Tensor: - original_dtype = t.dtype - return t.float().mean().to(original_dtype) - - return apply_to_collection(collection, torch.Tensor, mean) - @property def root_device(self): return self.parallel_devices[0] @@ -81,15 +70,6 @@ def root_device(self): def model_to_device(self) -> None: self._model.to(self.root_device) - def barrier(self, *args, **kwargs): - pass - - def broadcast(self, obj: object, src: int = 0) -> object: - return obj - - def reduce_boolean_decision(self, decision: bool) -> bool: - return decision - def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 72338e2923c07..7554113093fc8 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -16,6 +16,8 @@ import torch +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.torch_collective import TorchCollective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp import DDPPlugin @@ -42,6 +44,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, ): """Plugin for Fully Sharded Data Parallel provided by FairScale. @@ -93,6 +96,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + collective=collective or TorchCollective(), ) self.cpu_offload = cpu_offload self.move_grads_to_cpu = move_grads_to_cpu @@ -167,7 +171,7 @@ def pre_dispatch(self) -> None: if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) self.configure_ddp() - self.barrier() + self.collective.barrier() def model_to_device(self) -> None: # ensure we update the device type in the lightning module diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index c7a6fefea6542..efb37ceeed097 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
from contextlib import ExitStack -from typing import Any, List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch import torch.nn as nn @@ -20,12 +20,12 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import distributed_available -from pytorch_lightning.utilities.distributed import group as dist_group -from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import rank_zero_only if _HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -38,8 +38,14 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + collective=collective or HorovodCollective(), + ) rank_zero_only.rank = self.global_rank @property @@ -65,6 +71,8 @@ def distributed_sampler_kwargs(self): def setup(self) -> None: self.model_to_device() + self.collective.on_gpu = self.on_gpu + self.collective.local_rank = self.local_rank def pre_dispatch(self): @@ -108,14 +116,14 @@ def start_training(self, trainer): self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user - self.join() + self.collective.join() def start_evaluating(self, trainer): with ExitStack(): self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user - self.join() + self.collective.join() def start_predicting(self, trainer): with ExitStack(): @@ -123,15 +131,7 @@ def start_predicting(self, trainer): self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user - self.join() - - def barrier(self, *args, **kwargs): - if distributed_available(): - self.join() - - def broadcast(self, obj: object, src: int = 0) -> object: - obj = hvd.broadcast_object(obj, src) - return obj + self.collective.join() def model_to_device(self): if self.on_gpu: @@ -139,54 +139,6 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def join(self): - if self.on_gpu: - hvd.join(self.local_rank) - else: - hvd.join() - - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): - """Reduces a tensor from several distributed processes to one aggregated tensor. - - Args: - tensor: the tensor to sync and reduce - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to 'mean'/'avg'. - Can also be a string 'sum' to calculate the sum during reduction. - - Return: - reduced value, except when the input was not a tensor the output remains is unchanged - """ - if group is not None: - raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. 
Unset `group`.") - - if reduce_op in (None, "avg", "mean"): - reduce_op = hvd.Average - elif reduce_op in ("sum", ReduceOp.SUM): - reduce_op = hvd.Sum - else: - raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") - - # sync all processes before reduction - self.join() - return hvd.allreduce(tensor, op=reduce_op) - - def all_gather( - self, result: Union[torch.Tensor], group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False - ) -> torch.Tensor: - if group is not None and group != dist_group.WORLD: - raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.") - - if len(result.shape) == 0: - # Convert scalars to single dimension tensors - result = result.reshape(1) - - # sync and gather all - self.join() - gathered = hvd.allgather(result) - gathered_result = list(gathered.split(1, dim=0)) - return gathered_result - def post_backward(self, closure_loss: torch.Tensor) -> None: # synchronize all horovod optimizers. for optimizer in self.lightning_module.trainer.optimizers: diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 16bc5e4e9be4b..315ab2b3f7afd 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -21,6 +21,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -65,6 +67,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, training_opts: Optional["poptorch.Options"] = None, inference_opts: Optional["poptorch.Options"] = None, ) -> None: @@ -85,6 +88,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + collective=collective or SingleDeviceCollective(), ) if not _POPTORCH_AVAILABLE or not poptorch.ipuHardwareIsAvailable(): raise MisconfigurationException( @@ -318,15 +322,3 @@ def model_to_device(self) -> None: @property def is_global_zero(self) -> bool: return True - - def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: - return tensor - - def barrier(self, name: Optional[str] = None) -> None: - pass - - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - return tensor - - def broadcast(self, obj: object, src: int = 0) -> object: - return obj diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index e6406ea444947..b9f4801675040 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -14,18 +14,18 @@ import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, List, Optional +from typing import List, Optional import torch from 
torch.nn.parallel import DistributedDataParallel import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module +from pytorch_lightning.plugins.collective.collective_plugin import Collective from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE -from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp class ParallelPlugin(TrainingTypePlugin, ABC): @@ -36,8 +36,12 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, ): - super().__init__(checkpoint_io) + super().__init__( + checkpoint_io=checkpoint_io, + collective=collective, + ) self.parallel_devices = parallel_devices self.cluster_environment = cluster_environment @@ -86,16 +90,6 @@ def distributed_sampler_kwargs(self): def reconciliate_processes(self, trace: str): """Function to re-conciliate processes on failure.""" - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes.""" - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.lightning_module.device) - decision = self.reduce(decision, reduce_op=ReduceOp.SUM) - decision = bool(decision == self.world_size) - return decision - @property def torch_distributed_backend(self): torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND") diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 1737bf3b41ca8..d3d70cb613455 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional, Union +from typing import Optional import torch +from pytorch_lightning.plugins.collective.collective_plugin import Collective +from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -27,8 +29,9 @@ def __init__( self, device: torch.device, checkpoint_io: Optional[CheckpointIO] = None, + collective: Optional[Collective] = None, ): - super().__init__(checkpoint_io) + super().__init__(checkpoint_io=checkpoint_io, collective=collective or SingleDeviceCollective()) self.device: torch.device = device self.global_rank = 0 self.local_rank = 0 @@ -42,24 +45,6 @@ def on_tpu(self) -> bool: def on_gpu(self) -> bool: return self.root_device.type == "cuda" and torch.cuda.is_available() - def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]: - """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only - operates with a single device, the reduction is simply the identity. - - Args: - tensor: the tensor to sync and reduce - *args: ignored - **kwargs: ignored - - Return: - the unmodified input as reduction is not needed for single process operation - """ - return tensor - - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes.""" - return tensor - @property def root_device(self) -> torch.device: return self.device @@ -74,12 +59,6 @@ def setup(self) -> None: def is_global_zero(self) -> bool: return True - def barrier(self, *args, **kwargs) -> None: - pass - - def broadcast(self, obj: object, src: int = 0) -> object: - return obj - def teardown(self) -> None: if self.on_gpu: # GPU teardown diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 9140a995c8b56..3bc664dfc658c 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import io
 import os
 import re
 import time
@@ -25,6 +24,8 @@
 import pytorch_lightning as pl
 from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.plugins.collective.collective_plugin import Collective
+from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
@@ -32,7 +33,7 @@
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.data import has_len
-from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
+from pytorch_lightning.utilities.distributed import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -53,8 +54,14 @@
 class TPUSpawnPlugin(DDPSpawnPlugin):
     """Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method."""

-    def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None:
-        super().__init__(parallel_devices=parallel_devices)
+    def __init__(
+        self,
+        parallel_devices: Optional[List[int]] = None,
+        collective: Optional[Collective] = None,
+        debug: bool = False,
+        **_: Any
+    ) -> None:
+        super().__init__(parallel_devices=parallel_devices, collective=collective or TPUCollective())
         self.debug = debug
         self.tpu_local_core_rank = 0
         self.tpu_global_core_rank = 0
@@ -157,6 +164,10 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
             trainer.progress_bar_callback.disable()

         self.model_to_device()
+        self.collective.device = self.lightning_module.device
+        self.collective.root_device = self.root_device
+        self.collective.world_size = self.world_size
+
         trainer.accelerator.setup_optimizers(trainer)
         trainer.precision_plugin.connect(self._model, None, None)

@@ -210,42 +221,6 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
     def save(self, state_dict: Dict, path: str) -> None:
         xm.save(state_dict, path)

-    def broadcast(self, obj: object, src: int = 0) -> object:
-        if not self.is_distributed:
-            return obj
-        buffer = io.BytesIO()
-        torch.save(obj, buffer)
-        data = bytearray(buffer.getbuffer())
-        data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
-        data = xm.all_gather(data_tensor)
-        buffer = io.BytesIO(data.cpu().byte().numpy())
-        obj = torch.load(buffer)
-        return obj
-
-    def reduce_boolean_decision(self, decision: bool) -> bool:
-        decision = torch.tensor(int(decision), device=self.lightning_module.device)
-        decision = self.reduce(decision, reduce_op="sum")
-        decision = bool(decision == self.world_size)
-        return decision
-
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
-        if not isinstance(output, torch.Tensor):
-            output = torch.tensor(output, device=self.lightning_module.device)
-
-        _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
-        _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
-        if _invalid_reduce_op or _invalid_reduce_op_str:
-            raise MisconfigurationException(
-                "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation."
-            )
-
-        output = xm.mesh_reduce("reduce", output, sum)
-
-        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
-            output = output / self.world_size
-
-        return output
-
     def _close_logger(self, trainer) -> None:
         if trainer.logger is not None:
             trainer.logger.finalize("success")
@@ -315,20 +290,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
         self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)

-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
-        """
-        Function to gather a tensor from several distributed processes
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: not available with TPUs
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-        return xm.all_gather(tensor)
-
     def teardown(self) -> None:
         # TPU teardown
         os.environ.pop("PT_XLA_DEBUG", None)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index e5c015c34ab59..d40e99a201111 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -25,6 +25,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins import TorchCheckpointIO
+from pytorch_lightning.plugins.collective.collective_plugin import Collective
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT

@@ -35,11 +36,12 @@ class TrainingTypePlugin(ABC):
     """Base class for all training type plugins that change the behaviour of the training, validation and test-
     loop."""

-    def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None:
+    def __init__(self, checkpoint_io: Optional[CheckpointIO] = None, collective: Optional[Collective] = None) -> None:
         self._model: Optional[Module] = None
         self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
         checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO()
         self._checkpoint_io = checkpoint_io
+        self.collective = collective
         self._call_configure_sharded_model_hook = True

     @property
@@ -91,32 +93,6 @@ def model_to_device(self) -> None:
     def is_global_zero(self) -> bool:
         """Whether the current process is the rank zero process not only on the local node, but for all nodes."""

-    @abstractmethod
-    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
-        """Reduces the given tensor (e.g. across GPUs/processes).
-
-        Args:
-            tensor: the tensor to sync and reduce
-            *args: plugin-specific positional arguments
-            **kwargs: plugin-specific keyword arguments
-        """
-
-    @abstractmethod
-    def barrier(self, name: Optional[str] = None) -> None:
-        """Forces all possibly joined processes to wait for each other."""
-
-    @abstractmethod
-    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        """Broadcasts an object to all processes."""
-
-    @abstractmethod
-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
-        """Perform a all_gather on all processes."""
-
-    def reduce_boolean_decision(self, decision: bool) -> bool:
-        """Reduce the early stopping decision across all processes."""
-        return decision
-
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         """Run before precision plugin executes backward."""

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index dee8aea79a304..cdf911fd0feb0 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -76,7 +76,7 @@ def resume_end(self) -> None:
             torch.cuda.empty_cache()

         # wait for all to catch up
-        self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
+        self.trainer.training_type_plugin.collective.barrier("CheckpointConnector.resume_end")

     def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
         """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and
diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py
index 12ec14712fc94..530efa5633188 100644
--- a/tests/checkpointing/test_checkpoint_callback_frequency.py
+++ b/tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -111,7 +111,7 @@ def training_epoch_end(self, outputs) -> None:
             self.log("my_loss_2", (1 + local_rank), on_epoch=True, rank_zero_only=True)
             data = str(self.global_rank)
             obj = [[data], (data,), set(data)]
-            out = self.trainer.training_type_plugin.broadcast(obj)
+            out = self.trainer.training_type_plugin.collective.broadcast(obj)
             assert obj == [[str(self.global_rank)], (str(self.global_rank),), set(str(self.global_rank))]
             assert out == [["0"], ("0",), set("0")]

diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index a7572f9a77394..2f9b577d6a0d4 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -316,7 +316,7 @@ def on_save_checkpoint(self, checkpoint) -> None:
             assert state_dict["items"]["validation_step.v"]["value"].device.type == device

             # sync fn should be kept
-            assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce
+            assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.collective.reduce

             # sync fn dropped from the state dict
             assert "fn" not in state_dict["items"]["validation_step.v"]["meta"]["_sync"]
@@ -326,7 +326,7 @@ def on_save_checkpoint(self, checkpoint) -> None:
             assert results["validation_step.v"].value.device.type == device

             # sync fn was preserved in the original result
-            assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce
+            assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.collective.reduce

             # default sync fn
             new_results = ResultCollection(False, device)
@@ -453,7 +453,7 @@ def on_epoch_end(self) -> None:
     assert not model.has_validated_sum

     tmpdir = (
-        trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0)
+        trainer.training_type_plugin.collective.broadcast(trainer_kwargs["default_root_dir"], 0)
         if num_processes >= 2
         else trainer_kwargs["default_root_dir"]
     )
diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index 71acb9a168081..b0261cff5ccb4 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -58,7 +58,9 @@ def on_train_start(self) -> None:
             assert self.device == expected_device

         def training_epoch_end(self, outputs) -> None:
-            res = self.trainer.training_type_plugin.reduce(torch.tensor(1.0, device=self.device), reduce_op="sum")
+            res = self.trainer.training_type_plugin.collective.reduce(
+                torch.tensor(1.0, device=self.device), reduce_op="sum"
+            )
             assert res.sum() == self.trainer.training_type_plugin.world_size

     model = TestModel()
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 950c3577b89b9..918336cc0be08 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -292,7 +292,7 @@ def test_broadcast(rank):
         assert isinstance(trainer.accelerator, TPUAccelerator)
         assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
         obj = ("ver_0.5", "logger_name", rank)
-        result = trainer.training_type_plugin.broadcast(obj)
+        result = trainer.training_type_plugin.collective.broadcast(obj)
         assert result == ("ver_0.5", "logger_name", 0)

     xmp.spawn(test_broadcast, nprocs=8, start_method="fork")
@@ -356,9 +356,9 @@ def test_reduce(rank):
         for reduce_op in reduce_ops:
             if reduce_op == "undefined" or reduce_op == ReduceOp.MAX:
                 with pytest.raises(MisconfigurationException, match="TPUSpawn TrainingTypePlugin only support"):
-                    result = trainer.training_type_plugin.reduce(1, reduce_op)
+                    result = trainer.training_type_plugin.collective.reduce(1, reduce_op)
             else:
-                result = trainer.training_type_plugin.reduce(1, reduce_op)
+                result = trainer.training_type_plugin.collective.reduce(1, reduce_op)
             if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
                 assert result.item() == 1
             else: