
Commit 558956d

2/n Consolidate collective functions - collective base and subclasses
1 parent: ca90f68

5 files changed (+37 lines added, -18 lines removed)

pytorch_lightning/plugins/collective/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.plugins.collective.collective_plugin import Collective  # noqa: F401
-from pytorch_lightning.plugins.collective.single_device_collective import SingleNodeCollective  # noqa: F401
-from pytorch_lightning.plugins.collective.torch_collective import TorchCollective  # noqa: F401
 from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective  # noqa: F401
+from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective  # noqa: F401
+from pytorch_lightning.plugins.collective.torch_collective import TorchCollective  # noqa: F401
 from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective  # noqa: F401
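
The re-exports with `# noqa: F401` exist so that callers import every collective from the package root instead of individual modules. A minimal sketch of how downstream code picks up the renamed class after this change (assumes a checkout of this branch, not a released version):

# Hypothetical usage sketch: import the collectives from the package root,
# where the renamed SingleDeviceCollective is now re-exported.
from pytorch_lightning.plugins.collective import (
    Collective,
    SingleDeviceCollective,
    TorchCollective,
)

# Any code still importing SingleNodeCollective would now raise ImportError,
# so call sites must be updated together with this rename.
assert issubclass(SingleDeviceCollective, Collective)
assert issubclass(TorchCollective, Collective)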

pytorch_lightning/plugins/collective/horovod_collective.py

Lines changed: 5 additions & 5 deletions
@@ -31,19 +31,19 @@ class HorovodCollective(Collective):
     def __init__(
         self,
         on_gpu: Optional[bool] = False,
-        local_rank: Optional[int] = 0,
+        local_rank: int = 0,
     ):
-        self._on_gpu = on_gpu
-        self._local_rank = local_rank
+        self.on_gpu = on_gpu
+        self.local_rank = local_rank
 
     def join(self) -> None:
         """Horovod function that indicates that the rank finished processing data.
 
         All ranks that did not call join() continue to process allreduce operations. This function blocks Python thread
         until all ranks join.
         """
-        if self._on_gpu:
-            hvd.join(self.local_rank)
+        if self.on_gpu:
+            hvd.join(self.local_rank)
         else:
             hvd.join()
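
A minimal usage sketch of the changed constructor, assuming Horovod with PyTorch support is installed and initialised; the attributes are now public, so callers and tests can read them directly instead of the old _on_gpu/_local_rank names:

import torch
import horovod.torch as hvd

from pytorch_lightning.plugins.collective import HorovodCollective

hvd.init()  # Horovod must be initialised before any collective call

# local_rank is now a plain int (no Optional); on_gpu stays Optional[bool].
collective = HorovodCollective(on_gpu=torch.cuda.is_available(), local_rank=hvd.local_rank())
print(collective.on_gpu, collective.local_rank)

# join() forwards the local rank to hvd.join() only on GPU, so the device
# owning this rank can be released correctly; otherwise it calls hvd.join().
collective.join()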

pytorch_lightning/plugins/collective/single_device_collective.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 from pytorch_lightning.plugins.collective import Collective
 
 
-class SingleNodeCollective(Collective):
+class SingleDeviceCollective(Collective):
     """Collective interface for single device training type plugins."""
 
     def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
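
The class body is unchanged by this commit; only the name moves from "node" to "device", matching the single-device plugin naming. For illustration only (not the file's actual contents), a single-device collective typically reduces every operation to a no-op, roughly like this:

from typing import Any, Optional

import torch


class SingleDeviceCollectiveSketch:
    """Illustrative stand-in: with one process, every collective is a no-op."""

    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
        # Nothing to synchronise with, so return immediately.
        pass

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        # The only process already holds the object.
        return obj

    def all_gather(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        # Gathering across a single process is the tensor itself.
        return tensor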

pytorch_lightning/plugins/collective/torch_collective.py

Lines changed: 16 additions & 7 deletions
@@ -29,7 +29,14 @@
 class TorchCollective(Collective):
     """Collective interface for DDP, DDPSpawn, DP and DDP2."""
 
-    def __init__(self, local_reduce: bool = False, rank=None, device=None):
+    def __init__(
+        self,
+        local_reduce: bool = False,
+        rank: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = torch.device("cpu"),
+        device_id: Optional[int] = None,
+        world_size: int = 1,
+    ):
         """.. note::
 
             DDP and DDPSpawn sync accross multiple nodes/devices, local_reduce = False
@@ -38,19 +45,21 @@ def __init__(self, local_reduce: bool = False, rank=None, device=None):
 
             local_reduce set in Plugins.setup() functions
         """
-        self._local_reduce = local_reduce
-        self._rank = rank
-        self._device = device
+        self.local_reduce = local_reduce
+        self.rank = rank
+        self.device = device
+        self.device_id = device_id
+        self.world_size = world_size
 
     def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
         if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
-            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
+            torch.distributed.barrier(device_ids=self.device_id)
         else:
             torch.distributed.barrier()
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: Any, src: int = 0) -> Any:
         if not distributed_available():
             return obj
         else:
@@ -97,7 +106,7 @@ def mean(t: torch.Tensor) -> torch.Tensor:
             return tensor
 
     def reduce_boolean_decision(self, decision: bool) -> bool:
-        decision = torch.tensor(int(decision), device=self.lightning_module.device)
+        decision = torch.tensor(int(decision), device=self.device)
         decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
         decision = bool(decision == self.world_size)
         return decision
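
The reduce_boolean_decision change makes the collective self-contained: it now uses its own device and world_size rather than reaching into lightning_module. A standalone sketch of the same logic written against raw torch.distributed (illustrative function name, assumes the default process group is already initialised):

import torch
import torch.distributed as dist


def reduce_boolean_decision_sketch(decision: bool, device: torch.device, world_size: int) -> bool:
    """All ranks agree only if every rank voted True, i.e. the SUM equals world_size."""
    tensor = torch.tensor(int(decision), device=device)   # encode the vote as 0 or 1
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)          # sum the votes across ranks
    return bool(tensor.item() == world_size)

In the same spirit, barrier() now takes the NCCL device_ids value from the new device_id constructor argument instead of recomputing it per call via determine_ddp_device_ids().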

pytorch_lightning/plugins/collective/tpu_collective.py

Lines changed: 13 additions & 3 deletions
@@ -25,12 +25,22 @@
     import torch_xla.core.xla_model as xm
     from torch_xla.core.xla_model import rendezvous
 else:
-    xm, rendezvous = [None] * 4
+    xm, rendezvous = [None] * 2
 
 
 class TPUCollective(Collective):
     """Collective interface for TPU and TPUSpawning training type plugins."""
 
+    def __init__(
+        self,
+        device: Union[str, torch.device] = torch.device("xla"),
+        root_device: torch.device = xm.xla_device(),
+        world_size: int = xm.xrt_world_size(),
+    ):
+        self.device = device
+        self.root_device = root_device
+        self.world_size = world_size
+
     def barrier(self, name: Optional[str] = None) -> None:
         if self.is_distributed:
             rendezvous(name)
@@ -59,11 +69,11 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         """
         if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
-        return self._xm.all_gather(tensor)
+        return xm.all_gather(tensor)
 
     def reduce(self, output: Any, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Any:
         if not isinstance(output, torch.Tensor):
-            output = torch.tensor(output, device=self.lightning_module.device)
+            output = torch.tensor(output, device=self.device)
 
         _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
         _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
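
The fallback fix on the guarded import matters because unpacking two names from a four-element list raises at import time whenever XLA is unavailable. A small self-contained illustration of why [None] * 2 is the correct fallback:

# Unpacking fallbacks: the left-hand side must match the list length exactly.
xm, rendezvous = [None] * 2      # OK: two names, two values
try:
    xm, rendezvous = [None] * 4  # ValueError: too many values to unpack (expected 2)
except ValueError as err:
    print(err)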
