2/n Consolidate collective functions - collective base and subclasses #9414
Changes from all commits
pytorch_lightning/plugins/collective/__init__.py
@@ -0,0 +1,18 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.collective.collective_plugin import CollectivePlugin  # noqa: F401
from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective  # noqa: F401
from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective  # noqa: F401
from pytorch_lightning.plugins.collective.torch_collective import TorchCollective  # noqa: F401
from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective  # noqa: F401
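
The new `__init__.py` gives all collectives a single import surface. A minimal sketch of how downstream code might use it once this PR lands (the variable names are illustrative, not part of the PR):

# Illustrative usage of the consolidated namespace added by this PR.
from pytorch_lightning.plugins.collective import (
    CollectivePlugin,
    SingleDeviceCollective,
    TorchCollective,
)

# A training type plugin would pick the subclass that matches its strategy;
# the base class is the common type it can program against.
collective: CollectivePlugin = SingleDeviceCollective()
assert isinstance(collective, CollectivePlugin)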
pytorch_lightning/plugins/collective/collective_plugin.py
@@ -0,0 +1,79 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union

import torch

from pytorch_lightning.utilities.distributed import ReduceOp


class CollectivePlugin(ABC):
    """Interface for collective functions.

    Lightning collectives support communication between multiple processes and multiple nodes and provide
    routines such as barrier, broadcast, all_gather, and reduce.

    .. note::
        This API is experimental/in-beta and subject to change.
    """

    @abstractmethod
    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronizes all processes, blocking each process until the whole group enters this function.

        Args:
            name: an optional string passed to the barrier. Only torch_xla respects this parameter.
        """

    @abstractmethod
    def broadcast(self, obj: object, src: int = 0) -> object:
        """Broadcasts an object to all processes.

        Args:
            obj: the object to broadcast
            src: source rank.
        """

    @abstractmethod
    def all_gather(
        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
    ) -> Union[List[torch.Tensor], torch.Tensor]:
        """Performs an all_gather across all processes.

        Args:
            tensor: the tensor to all_gather
            process_group: the process group to gather results from
            sync_grads: flag that allows users to synchronize gradients for the all_gather op

        Returns: a tensor (torch.distributed) or a list of tensors (Horovod)
        """

    @abstractmethod
    def reduce(
        self,
        tensor: Union[torch.Tensor, Any],
        process_group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[torch.Tensor, Any]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        Args:
            tensor: the tensor to sync and reduce
            process_group: the process group to reduce
            reduce_op: the reduction operation. Defaults to 'mean'.
                Can also be the string 'sum' or a ReduceOp.
        """
pytorch_lightning/plugins/collective/horovod_collective.py
@@ -0,0 +1,93 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union

import torch

from pytorch_lightning.plugins.collective import CollectivePlugin
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp

if _HOROVOD_AVAILABLE:
    import horovod.torch as hvd


class HorovodCollective(CollectivePlugin):
    """Collective interface for Horovod training type plugins."""

    def __init__(
        self,
        on_gpu: bool = False,
        local_rank: int = 0,
    ) -> None:
        self.on_gpu = on_gpu
        self.local_rank = local_rank

    def _join(self) -> None:
        """Horovod function that indicates that the rank finished processing data.

        All ranks that did not call join() continue to process allreduce operations. This function blocks the
        Python thread until all ranks join.
        """
        if self.on_gpu:
            hvd.join(self.local_rank)
        else:
            hvd.join()

    def barrier(self, *args: Any, **kwargs: Any) -> None:
        if distributed_available():
            self._join()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return hvd.broadcast_object(obj, src)

    def all_gather(
        self, result: Union[torch.Tensor], process_group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
    ) -> List[torch.Tensor]:
        if process_group is not None and process_group != dist_group.WORLD:
            raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")

        if len(result.shape) == 0:
            # Convert scalars to single dimension tensors
            result = result.reshape(1)

        # sync and gather all
        self._join()
        gathered = hvd.allgather(result)
        gathered_result = list(gathered.split(1, dim=0))
        return gathered_result

    def reduce(
        self,
        tensor: Union[torch.Tensor, Any],
        process_group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[torch.Tensor, Any]:
        if process_group is not None:
            raise ValueError(
                "Horovod does not support allreduce using a subcommunicator at this time. Unset `process_group`."
            )

        if reduce_op in (None, "avg", "mean"):
            reduce_op = hvd.Average
        elif reduce_op in ("sum", ReduceOp.SUM):
            reduce_op = hvd.Sum
        else:
            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

        # sync all processes before reduction
        self._join()
        return hvd.allreduce(tensor, op=reduce_op)
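
A quick usage sketch, assuming Horovod is installed and the script is launched through `horovodrun` (the tensor values and the two-process setup are illustrative):

# Illustrative only: exercising HorovodCollective outside of a Trainer.
# Launch with e.g. `horovodrun -np 2 python this_script.py`.
import horovod.torch as hvd
import torch

from pytorch_lightning.plugins.collective import HorovodCollective

hvd.init()
collective = HorovodCollective(on_gpu=False, local_rank=hvd.local_rank())

# Every rank contributes its own rank id; with 2 processes the mean is 0.5.
reduced = collective.reduce(torch.tensor(float(hvd.rank())), reduce_op="mean")

# broadcast() ships an arbitrary picklable object from rank 0 to all ranks.
settings = collective.broadcast({"lr": 0.1}, src=0)

# all_gather() returns one tensor per rank as a list.
gathered = collective.all_gather(torch.tensor(float(hvd.rank())))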
pytorch_lightning/plugins/collective/single_device_collective.py
@@ -0,0 +1,36 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch

from pytorch_lightning.plugins.collective import CollectivePlugin


class SingleDeviceCollective(CollectivePlugin):
    """Collective interface for single device training type plugins."""

    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
        pass

    def broadcast(self, obj: object, src: int = 0) -> object:
        return obj

    def all_gather(
        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
    ) -> torch.Tensor:
        return tensor

    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
        return tensor
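
With a single process there is nothing to synchronize, so every call is a pass-through; a small sketch of what callers can rely on (values are illustrative):

# Illustrative only: on a single device every collective op is a no-op pass-through.
import torch

from pytorch_lightning.plugins.collective import SingleDeviceCollective

collective = SingleDeviceCollective()
t = torch.tensor([1.0, 2.0])

collective.barrier()                              # returns immediately
assert collective.broadcast({"step": 3}) == {"step": 3}
assert torch.equal(collective.all_gather(t), t)   # no extra gather dimension is added
assert torch.equal(collective.reduce(t), t)       # nothing to reduce against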
pytorch_lightning/plugins/collective/torch_collective.py
@@ -0,0 +1,108 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch
import torch.distributed

from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.plugins.collective import CollectivePlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class TorchCollective(CollectivePlugin):
    """Collective interfaces for PyTorch.

    Mainly used by DDP, DDPSpawn, DP and DDP2.
    """

    def __init__(
        self,
        local_reduce: bool = False,
        rank: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = torch.device("cpu"),
        device_id: Optional[int] = None,
    ) -> None:
        """
        Note:
            DDP and DDPSpawn sync across multiple nodes/devices, local_reduce = False.
            DP runs reduce on one node, local_reduce = True.
            DDP2 behaves like DP within one node, local_reduce = True.

            local_reduce is set in the plugins' setup() functions.
        """
        self.local_reduce = local_reduce
        self.rank = rank
        self.device = device
        self.device_id = device_id

    def barrier(self, *args: Any, **kwargs: Any) -> None:
        if not distributed_available():
            return
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self.device_id)
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        if not distributed_available():
            return obj
        obj = [obj]
        if self.rank != 0:
            obj = [None] * len(obj)
        broadcast_object_list(obj, src, group=dist_group.WORLD)
        return obj[0]

    def all_gather(
        self, tensor: torch.Tensor, process_group: Optional[Any] = None, sync_grads: bool = False
    ) -> torch.Tensor:
        return all_gather_ddp_if_available(tensor, group=process_group, sync_grads=sync_grads)

    def reduce(
        self,
        tensor: Union[torch.Tensor, _METRIC_COLLECTION],
        process_group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[torch.Tensor, _METRIC_COLLECTION]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        If local_reduce = True (DP and DDP2), reduces the tensor across all local processes.

        If local_reduce = False (DDP, DDPSpawn and extensions), reduces a tensor from several distributed processes.

        Args:
            tensor: the tensor to sync and reduce
            process_group: the process group to reduce
            reduce_op: the reduction operation. Defaults to 'mean'.
                Can also be the string 'sum' or a ReduceOp.

        Return:
            the reduced value, except when the input was not a tensor, in which case the output remains unchanged
        """
        if self.local_reduce:

            def mean(t: torch.Tensor) -> torch.Tensor:
                original_dtype = t.dtype
                return t.float().mean().to(original_dtype)

            return apply_to_collection(tensor, torch.Tensor, mean)

        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, process_group, reduce_op=reduce_op)
        return tensor
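
The `local_reduce` flag is the main behavioural fork: when it is `True` (the DP/DDP2 style), `reduce` simply averages every tensor in the collection without touching `torch.distributed`. A sketch of that local path, runnable in a single process (the metric dict is illustrative):

# Illustrative only: the local_reduce=True path needs no initialized process
# group, so it can be exercised in a single process.
import torch

from pytorch_lightning.plugins.collective import TorchCollective

collective = TorchCollective(local_reduce=True)

metrics = {"loss": torch.tensor([0.2, 0.4]), "acc": torch.tensor([0.9, 0.7])}
reduced = collective.reduce(metrics)
# apply_to_collection maps the mean over every tensor leaf:
# {"loss": tensor(0.3000), "acc": tensor(0.8000)}

# With local_reduce=False the same call routes through sync_ddp_if_available,
# which is a no-op unless torch.distributed has been initialized.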