Skip to content

2/n Consolidate collective functions - collective base and subclasses #9414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pytorch_lightning/plugins/collective/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.collective.collective_plugin import CollectivePlugin # noqa: F401
from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective # noqa: F401
from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective # noqa: F401
from pytorch_lightning.plugins.collective.torch_collective import TorchCollective # noqa: F401
from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective # noqa: F401
79 changes: 79 additions & 0 deletions pytorch_lightning/plugins/collective/collective_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union

import torch

from pytorch_lightning.utilities.distributed import ReduceOp


class CollectivePlugin(ABC):
"""Interface for collective functions.

Lightning collective supports communications between multiple processes and multiple nodes, provides routines such
as barrier, broadcast, all_gather, and reduce

.. note::
This API is experimental/in-beta and subject to change
"""

@abstractmethod
def barrier(self, name: Optional[str] = None) -> None:
    """Synchronize all processes by blocking each one until the whole group has entered this function.

    Args:
        name: an optional string passed into the barrier. Only torch-xla respects this parameter.
    """

@abstractmethod
def broadcast(self, obj: object, src: int = 0) -> object:
"""Broadcasts an object to all processes.

Args:
obj: the object to broadcast
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not necessarily true.

  1. Horovod's broadcast_object will return the object that was broadcast from the root_rank.
  2. For PyTorch versions < 1.8, the torch collective uses broadcast, which returns an async work handle if async_op is set to True, and None if async_op is not set or the rank is not part of the group.
  3. For PyTorch versions >= 1.8, the torch collective uses broadcast, which returns None. If the rank is part of the group, object_list will contain the broadcasted objects from the src rank.

src: source rank.
"""

@abstractmethod
def all_gather(
    self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
) -> Union[List[torch.Tensor], torch.Tensor]:
    """Perform an all_gather on all processes.

    Args:
        tensor: the tensor to all_gather
        process_group: the process group to gather results from
        sync_grads: flag that allows users to synchronize gradients for the all_gather op

    Returns:
        A tensor (torch distributed) or a list of tensors (horovod).
    """

@abstractmethod
def reduce(
    self,
    tensor: Union[torch.Tensor, Any],
    process_group: Optional[Any] = None,
    reduce_op: Optional[Union[ReduceOp, str]] = "mean",
) -> Union[torch.Tensor, Any]:
    """Reduce the given tensor (e.g. across GPUs/processes).

    Args:
        tensor: the tensor to sync and reduce
        process_group: the process group to reduce
        reduce_op: the reduction operation. Defaults to 'mean'.
            Can also be a string 'sum' or ReduceOp.

    Returns:
        The reduced value. How non-tensor inputs are treated is left to the implementing subclass.
    """
93 changes: 93 additions & 0 deletions pytorch_lightning/plugins/collective/horovod_collective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union

import torch

from pytorch_lightning.plugins.collective import CollectivePlugin
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp

if _HOROVOD_AVAILABLE:
import horovod.torch as hvd


class HorovodCollective(CollectivePlugin):
    """Collective interface for Horovod training type plugins."""

    def __init__(
        self,
        on_gpu: bool = False,
        local_rank: int = 0,
    ) -> None:
        """Store the settings consulted when joining ranks.

        Args:
            on_gpu: when True, ``hvd.join`` is called with ``local_rank`` as its argument
            local_rank: the value handed to ``hvd.join`` when ``on_gpu`` is True
        """
        self.on_gpu = on_gpu
        self.local_rank = local_rank

    def _join(self) -> None:
        """Horovod function that indicates that the rank finished processing data.

        All ranks that did not call join() continue to process allreduce operations. This function blocks the Python
        thread until all ranks join.
        """
        if self.on_gpu:
            hvd.join(self.local_rank)
        else:
            hvd.join()

    def barrier(self, *args: Any, **kwargs: Any) -> None:
        """Join all ranks, but only when ``distributed_available()`` reports True.

        Positional and keyword arguments are accepted for interface compatibility and ignored.
        """
        if distributed_available():
            self._join()

    def broadcast(self, obj: object, src: int = 0) -> object:
        """Broadcast ``obj`` from rank ``src`` via ``hvd.broadcast_object`` and return its result."""
        return hvd.broadcast_object(obj, src)

    def all_gather(
        self, result: torch.Tensor, process_group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
    ) -> List[torch.Tensor]:
        """Gather ``result`` from every rank.

        Args:
            result: the tensor to gather; zero-dimensional tensors are reshaped to one element first.
                NOTE(review): the base interface names this parameter ``tensor``; the name is kept here
                so existing keyword callers keep working.
            process_group: must be ``None`` or ``dist_group.WORLD``
            sync_grads: not used by this implementation

        Returns:
            The gathered tensor split along dim 0 into a list of size-1 chunks.

        Raises:
            ValueError: if a process group other than the world group is requested.
        """
        if process_group is not None and process_group != dist_group.WORLD:
            # The parameter is called `process_group`; the message previously pointed at a nonexistent `group`.
            raise ValueError(
                "Horovod does not support allgather using a subcommunicator at this time. Unset `process_group`."
            )

        if len(result.shape) == 0:
            # Convert scalars to single dimension tensors
            result = result.reshape(1)

        # sync and gather all
        self._join()
        gathered = hvd.allgather(result)
        return list(gathered.split(1, dim=0))

    def reduce(
        self,
        tensor: Union[torch.Tensor, Any],
        process_group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[torch.Tensor, Any]:
        """Reduce ``tensor`` across all ranks with ``hvd.allreduce``.

        Args:
            tensor: the value to reduce
            process_group: must be ``None``
            reduce_op: ``None``, ``"avg"`` or ``"mean"`` select ``hvd.Average``;
                ``"sum"`` or ``ReduceOp.SUM`` select ``hvd.Sum``

        Returns:
            The result of ``hvd.allreduce``.

        Raises:
            ValueError: if ``process_group`` is given or ``reduce_op`` is not recognized.
        """
        if process_group is not None:
            raise ValueError(
                "Horovod does not support allreduce using a subcommunicator at this time. Unset `process_group`."
            )

        if reduce_op in (None, "avg", "mean"):
            reduce_op = hvd.Average
        elif reduce_op in ("sum", ReduceOp.SUM):
            reduce_op = hvd.Sum
        else:
            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

        # sync all processes before reduction
        self._join()
        return hvd.allreduce(tensor, op=reduce_op)
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch

from pytorch_lightning.plugins.collective import CollectivePlugin


class SingleDeviceCollective(CollectivePlugin):
    """Collective interface for single device training type plugins.

    Every method here is either a no-op or returns its input unchanged.
    """

    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
        """Do nothing; all arguments are ignored."""
        return None

    def broadcast(self, obj: object, src: int = 0) -> object:
        """Return ``obj`` as-is; ``src`` is ignored."""
        return obj

    def all_gather(
        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
    ) -> torch.Tensor:
        """Return ``tensor`` as-is; ``process_group`` and ``sync_grads`` are ignored."""
        return tensor

    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
        """Return ``tensor`` as-is; any further arguments are ignored."""
        return tensor
108 changes: 108 additions & 0 deletions pytorch_lightning/plugins/collective/torch_collective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch
import torch.distributed

from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.plugins.collective import CollectivePlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class TorchCollective(CollectivePlugin):
    """Collective interfaces for PyTorch.

    Mainly used by DDP, DDPSpawn, DP and DDP2.
    """

    def __init__(
        self,
        local_reduce: bool = False,
        rank: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = torch.device("cpu"),
        device_id: Optional[int] = None,
    ) -> None:
        """
        Args:
            local_reduce: when True, ``reduce`` averages locally instead of syncing across processes
            rank: this process' rank, compared against ``src`` in ``broadcast``
            device: stored on the instance
            device_id: device index forwarded to ``torch.distributed.barrier`` on the NCCL backend

        Note:
            DDP and DDPSpawn sync across multiple nodes/devices, local_reduce = False
            DP runs reduce on one node, local_reduce = True
            DDP2 behaves like DP in one node, local_reduce = True

            local_reduce is set in the Plugins.setup() functions
        """
        self.local_reduce = local_reduce
        self.rank = rank
        self.device = device
        self.device_id = device_id

    def barrier(self, *args: Any, **kwargs: Any) -> None:
        """Block until all processes reach this call; a no-op when ``distributed_available()`` is False.

        Positional and keyword arguments are accepted for interface compatibility and ignored.
        """
        if not distributed_available():
            return
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            # `device_ids` expects a list of ints, so a bare index is wrapped.
            # None and already-sequence values are passed through untouched.
            device_ids = self.device_id
            if isinstance(device_ids, int):
                device_ids = [device_ids]
            torch.distributed.barrier(device_ids=device_ids)
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        """Broadcast ``obj`` from rank ``src`` to every process and return the received object.

        Returns ``obj`` unchanged when ``distributed_available()`` is False.
        """
        if not distributed_available():
            return obj
        # Only the source rank contributes its object; every other rank supplies a placeholder
        # that `broadcast_object_list` overwrites. Comparing against `src` (rather than a
        # hard-coded 0) keeps the payload intact when broadcasting from a non-zero rank.
        payload = [obj] if self.rank == src else [None]
        broadcast_object_list(payload, src, group=dist_group.WORLD)
        return payload[0]

    def all_gather(
        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
    ) -> torch.Tensor:
        """Delegate to ``all_gather_ddp_if_available`` with the given group and ``sync_grads`` flag."""
        return all_gather_ddp_if_available(tensor, group=process_group, sync_grads=sync_grads)

    def reduce(
        self,
        tensor: Union[torch.Tensor, _METRIC_COLLECTION],
        process_group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[torch.Tensor, _METRIC_COLLECTION]:
        """Reduces the given tensor (e.g. across GPUs/processes)

        If local_reduce = True (dp and ddp2), reduces tensor from all local processes.

        If local_reduce = False (ddp, ddpspawn and extensions), reduces a tensor from several distributed processes

        Args:
            tensor: the tensor to sync and reduce
            process_group: the process group to reduce
            reduce_op: the reduction operation. Defaults to 'mean'.
                Can also be a string 'sum' or ReduceOp.

        Return:
            reduced value, except when the input was not a tensor, in which case the output remains unchanged
        """
        if self.local_reduce:
            # Local path: every tensor in the collection is replaced by its mean, computed in
            # float and cast back. `process_group` and `reduce_op` are not consulted here.

            def mean(t: torch.Tensor) -> torch.Tensor:
                original_dtype = t.dtype
                return t.float().mean().to(original_dtype)

            return apply_to_collection(tensor, torch.Tensor, mean)

        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, process_group, reduce_op=reduce_op)
        return tensor
Loading