Commit 7c0d4ab

2/n Consolidate collective functions - collective base and subclasses

1 parent 089ae9b

File tree: 7 files changed, +397 −0 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `inference_mode` for evaluation and prediction ([8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813))


+- Add collective base class and subclasses ([#9414](https://github.com/PyTorchLightning/pytorch-lightning/pull/9414))
+
+
 ### Changed

 - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
pytorch_lightning/plugins/collective/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.collective.collective_plugin import Collective  # noqa: F401
from pytorch_lightning.plugins.collective.single_device_collective import SingleNodeCollective  # noqa: F401
from pytorch_lightning.plugins.collective.torch_collective import TorchCollective  # noqa: F401
from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective  # noqa: F401
from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective  # noqa: F401
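
The package `__init__` above re-exports every collective class, so downstream code can import them from the package root instead of the individual modules. A minimal sketch of the intended import path (illustrative only; it assumes this branch of the PR series is installed, and `tpu_collective.py` is added in another file of the commit):

# Illustrative usage of the re-exports above (not part of the commit).
from pytorch_lightning.plugins.collective import Collective, TorchCollective

collective: Collective = TorchCollective(local_reduce=False)
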
pytorch_lightning/plugins/collective/collective_plugin.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Optional, Union

import torch


class Collective(ABC):
    """Base class for collective functions for training type plugins."""

    @abstractmethod
    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
        """Forces all possibly joined processes to wait for each other."""

    @abstractmethod
    def broadcast(self, obj: object, src: int = 0) -> object:
        """Broadcasts an object to all processes."""

    @abstractmethod
    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Performs an all_gather on all processes."""

    @abstractmethod
    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        Args:
            tensor: the tensor to sync and reduce
            *args: plugin-specific positional arguments
            **kwargs: plugin-specific keyword arguments
        """

    def reduce_boolean_decision(self, decision: bool) -> bool:
        """Reduces the early stopping decision across all processes."""
        return decision
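
For context, this is how a training type plugin would be expected to delegate to its collective object; a minimal sketch under the assumption that the plugin holds a `Collective` instance (the function and argument names below are illustrative, not defined by this commit):

# Illustrative only: averaging a per-process loss through the Collective interface.
import torch


def sync_mean_loss(collective: Collective, loss: torch.Tensor) -> torch.Tensor:
    # Wait for every process, then delegate the reduction to the concrete subclass
    # (TorchCollective, HorovodCollective, ...), which interprets `reduce_op`.
    collective.barrier(name="sync_mean_loss")
    return collective.reduce(loss, reduce_op="mean")
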
pytorch_lightning/plugins/collective/horovod_collective.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from typing import Any, Optional, Union

import torch

from pytorch_lightning.plugins.collective import Collective
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import ReduceOp

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm
    from torch_xla.core.xla_model import rendezvous
else:
    xm, rendezvous = None, None

if _HOROVOD_AVAILABLE:
    import horovod.torch as hvd


class HorovodCollective(Collective):
    """Collective interface for Horovod training type plugins."""

    def __init__(
        self,
        on_gpu: Optional[bool] = False,
        local_rank: Optional[int] = 0,
    ):
        self._on_gpu = on_gpu
        self._local_rank = local_rank

    def join(self) -> None:
        """Horovod function that indicates that the rank finished processing data.

        All ranks that did not call join() continue to process allreduce operations. This function blocks the Python
        thread until all ranks join.
        """
        if self._on_gpu:
            hvd.join(self._local_rank)
        else:
            hvd.join()

    def barrier(self, name: Optional[str] = None) -> None:
        if self.is_distributed:
            rendezvous(name)

    def broadcast(self, obj: object, src: int = 0) -> object:
        if not self.is_distributed:
            return obj
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
        data = xm.all_gather(data_tensor)
        buffer = io.BytesIO(data.cpu().byte().numpy())
        obj = torch.load(buffer)
        return obj

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Gathers a tensor from several distributed processes.

        Args:
            tensor: tensor of shape (batch, ...)
            group: not available with TPUs
            sync_grads: not available with TPUs

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        return xm.all_gather(tensor)

    def reduce(
        self, tensor: Union[torch.Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ):
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            the reduced value, except when the input was not a tensor, in which case the output is unchanged
        """
        if group is not None:
            raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. Unset `group`.")

        if reduce_op in (None, "avg", "mean"):
            reduce_op = hvd.Average
        elif reduce_op in ("sum", ReduceOp.SUM):
            reduce_op = hvd.Sum
        else:
            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

        # sync all processes before reduction
        self.join()
        return hvd.allreduce(tensor, op=reduce_op)
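
As a usage sketch of the class above (illustrative only; it assumes Horovod is installed, `hvd.init()` has been called on every worker, and this branch of the PR series is installed):

# Illustrative only: average a per-rank scalar across all Horovod workers.
import horovod.torch as hvd
import torch

from pytorch_lightning.plugins.collective import HorovodCollective

hvd.init()
collective = HorovodCollective(on_gpu=False, local_rank=hvd.local_rank())
local_metric = torch.tensor(float(hvd.rank()))
# `reduce` accepts None/"avg"/"mean" (hvd.Average) or "sum"/ReduceOp.SUM (hvd.Sum).
global_mean = collective.reduce(local_metric, reduce_op="mean")
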
pytorch_lightning/plugins/collective/single_device_collective.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch

from pytorch_lightning.plugins.collective import Collective


class SingleNodeCollective(Collective):
    """Collective interface for single device training type plugins."""

    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
        """Forces all possibly joined processes to wait for each other. A no-op for a single device."""
        pass

    def broadcast(self, obj: object, src: int = 0) -> object:
        """Broadcasts an object to all processes."""
        return obj

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Performs an all_gather on all processes."""
        return tensor

    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        Args:
            tensor: the tensor to sync and reduce
            *args: plugin-specific positional arguments
            **kwargs: plugin-specific keyword arguments
        """
        return tensor
pytorch_lightning/plugins/collective/torch_collective.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch
import torch.distributed

from pytorch_lightning.plugins.collective import Collective
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import (
    all_gather_ddp_if_available,
    distributed_available,
    ReduceOp,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class TorchCollective(Collective):
    """Collective interface for DDP, DDPSpawn, DP and DDP2."""

    def __init__(self, local_reduce: bool = False):
        """
        .. note::
            DDP and DDPSpawn sync across multiple nodes/devices: ``local_reduce = False``.
            DP reduces only within one node: ``local_reduce = True``.
            DDP2 behaves like DP in one node: ``local_reduce = True``.

            ``local_reduce`` is set in the plugins' ``setup()`` functions.
        """
        self.local_reduce = local_reduce

    def barrier(self, *args: Any, **kwargs: Any) -> None:
        if not distributed_available():
            return
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        if not distributed_available():
            return obj
        return self.dist.broadcast(obj)

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Performs an all_gather on all processes."""
        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

    def reduce(
        self, tensor: _METRIC_COLLECTION, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean"
    ) -> _METRIC_COLLECTION:
        """Reduces the given tensor (e.g. across GPUs/processes).

        If ``local_reduce = True`` (DP and DDP2), reduces the tensor across all local processes.

        If ``local_reduce = False`` (DDP, DDPSpawn and extensions), reduces the tensor across several distributed
        processes.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            the reduced value, except when the input was not a tensor, in which case the output is unchanged
        """
        if self.local_reduce:

            def mean(t: torch.Tensor) -> torch.Tensor:
                original_dtype = t.dtype
                return t.float().mean().to(original_dtype)

            return apply_to_collection(tensor, torch.Tensor, mean)

        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def reduce_boolean_decision(self, decision: bool) -> bool:
        decision = torch.tensor(int(decision), device=self.lightning_module.device)
        decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
        decision = bool(decision == self.world_size)
        return decision
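
To make the two code paths in `reduce` concrete: with `local_reduce=True` (the DP/DDP2 style), each tensor in a metric collection is simply averaged in-process via `apply_to_collection`, so the sketch below runs without an initialised process group (illustrative only, assuming this branch of the PR series is installed):

# Illustrative only: the local_reduce=True branch of TorchCollective.reduce.
import torch

from pytorch_lightning.plugins.collective import TorchCollective

collective = TorchCollective(local_reduce=True)
metrics = {"loss": torch.tensor([0.5, 1.5]), "acc": torch.tensor([1.0, 0.0])}
reduced = collective.reduce(metrics)
# Each tensor is averaged locally: {"loss": tensor(1.), "acc": tensor(0.5)}
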
