Skip to content

3/n Consolidate collective functions - Integrate with TrainingTypePlugin #9472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344))


- Add collective base class and subclasses ([#9414](https://github.com/PyTorchLightning/pytorch-lightning/pull/9414))


### Changed

- `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
Expand Down Expand Up @@ -177,6 +180,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360))


- Integrate `collective` class with `TrainingTypePlugin` ([#9472](https://github.com/PyTorchLightning/pytorch-lightning/pull/9472))


### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
return self.training_type_plugin.lightning_module_state_dict()

def barrier(self, name: Optional[str] = None) -> None:
self.training_type_plugin.barrier(name=name)
self.training_type_plugin.collective.barrier(name=name)

def broadcast(self, obj: object, src: int = 0) -> object:
"""Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if
Expand All @@ -348,7 +348,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj: Object to broadcast to all process, usually a tensor or collection of tensors.
src: The source rank of which the object will be broadcast from
"""
return self.training_type_plugin.broadcast(obj, src)
return self.training_type_plugin.collective.broadcast(obj, src)

def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
"""Function to gather a tensor from several distributed processes.
Expand All @@ -361,7 +361,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
Return:
A tensor of shape (world_size, batch, ...)
"""
return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads)
return self.training_type_plugin.collective.all_gather(tensor, group=group, sync_grads=sync_grads)

def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Wraps the dataloader if necessary.
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
should_stop, reason = self._evaluate_stopping_criteria(current)

# stop every ddp process if any world process decides to stop
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
should_stop = trainer.training_type_plugin.collective.reduce_boolean_decision(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop:
self.stopped_epoch = trainer.current_epoch
Expand Down
10 changes: 6 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def on_train_batch_end(
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# in case we have time differences across ranks
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
skip_time = trainer.training_type_plugin.broadcast(skip_time)
skip_time = trainer.training_type_plugin.collective.broadcast(skip_time)

if skip_batch and skip_time:
return
Expand Down Expand Up @@ -509,7 +509,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Ten
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

# If using multiple devices, make sure all processes are unanimous on the decision.
should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save)
should_update_best_and_save = trainer.training_type_plugin.collective.reduce_boolean_decision(
should_update_best_and_save
)

return should_update_best_and_save

Expand Down Expand Up @@ -612,7 +614,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
else:
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

ckpt_path = trainer.training_type_plugin.broadcast(ckpt_path)
ckpt_path = trainer.training_type_plugin.collective.broadcast(ckpt_path)

self.dirpath = ckpt_path

Expand Down Expand Up @@ -748,4 +750,4 @@ def file_exists(self, filepath: Union[str, Path], trainer: "pl.Trainer") -> bool
"""Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
state to diverge between ranks."""
exists = self._fs.exists(filepath)
return trainer.training_type_plugin.broadcast(exists)
return trainer.training_type_plugin.collective.broadcast(exists)
8 changes: 4 additions & 4 deletions pytorch_lightning/callbacks/xla_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def on_train_start(self, trainer, pl_module) -> None:
)

memory_info = xm.get_memory_info(pl_module.device)
total_memory = trainer.training_type_plugin.reduce(memory_info["kb_total"]) * 0.001
total_memory = trainer.training_type_plugin.collective.reduce(memory_info["kb_total"]) * 0.001
rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")

def on_train_epoch_start(self, trainer, pl_module) -> None:
Expand All @@ -81,9 +81,9 @@ def on_train_epoch_end(self, trainer, pl_module) -> None:
free_memory = memory_info["kb_free"]
peak_memory = memory_info["kb_total"] - free_memory

free_memory = trainer.training_type_plugin.reduce(free_memory) * 0.001
peak_memory = trainer.training_type_plugin.reduce(peak_memory) * 0.001
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
free_memory = trainer.training_type_plugin.collective.reduce(free_memory) * 0.001
peak_memory = trainer.training_type_plugin.collective.reduce(peak_memory) * 0.001
epoch_time = trainer.training_type_plugin.collective.reduce(epoch_time)

logs["avg. free memory (MB)"] = free_memory
logs["avg. peak memory (MB)"] = peak_memory
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def log(
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
batch_size=batch_size,
sync_dist=sync_dist and distributed_available(),
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
sync_dist_fn=self.trainer.training_type_plugin.collective.reduce or sync_ddp,
sync_dist_group=sync_dist_group,
metric_attribute=metric_attribute,
rank_zero_only=rank_zero_only,
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ def _load_from_state_dict(
# On reload, we need to re-attach the `Metric`s back to the `ResultCollection`.
# The references are provided through the `metric_attributes` dictionary.
v.load_state_dict(
state_dict[prefix + k], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce
state_dict[prefix + k],
metrics=metric_attributes,
sync_fn=self.trainer.training_type_plugin.collective.reduce,
)

if not self.trainer.is_global_zero:
Expand Down
18 changes: 18 additions & 0 deletions pytorch_lightning/plugins/collective/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.collective.collective_plugin import Collective # noqa: F401
from pytorch_lightning.plugins.collective.horovod_collective import HorovodCollective # noqa: F401
from pytorch_lightning.plugins.collective.single_device_collective import SingleDeviceCollective # noqa: F401
from pytorch_lightning.plugins.collective.torch_collective import TorchCollective # noqa: F401
from pytorch_lightning.plugins.collective.tpu_collective import TPUCollective # noqa: F401
56 changes: 56 additions & 0 deletions pytorch_lightning/plugins/collective/collective_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union

import torch

from pytorch_lightning.utilities.distributed import ReduceOp


class Collective(ABC):
"""Base class for collective functions for training type plugins."""

@abstractmethod
def barrier(self, name: Optional[str] = None) -> None:
"""Forces all possibly joined processes to wait for each other."""

@abstractmethod
def broadcast(self, obj: object, src: int = 0) -> object:
"""Broadcasts an object to all processes."""

@abstractmethod
def all_gather(
self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Perform a all_gather on all processes."""

@abstractmethod
def reduce(
self,
tensor: Union[torch.Tensor, Any],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = "mean",
) -> Union[torch.Tensor, Any]:
"""Reduces the given tensor (e.g. across GPUs/processes).

Args:
tensor: the tensor to sync and reduce
*args: plugin-specific positional arguments
**kwargs: plugin-specific keyword arguments
"""

@abstractmethod
def reduce_boolean_decision(self, decision: bool) -> bool:
"""Reduce the early stopping decision across all processes."""
107 changes: 107 additions & 0 deletions pytorch_lightning/plugins/collective/horovod_collective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Union

import torch

from pytorch_lightning.plugins.collective import Collective
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp

if _HOROVOD_AVAILABLE:
import horovod.torch as hvd


class HorovodCollective(Collective):
"""Collective interface for Horovod training type plugins."""

def __init__(
self,
on_gpu: Optional[bool] = False,
local_rank: int = 0,
) -> None:
self.on_gpu = on_gpu
self.local_rank = local_rank

def join(self) -> None:
"""Horovod function that indicates that the rank finished processing data.

All ranks that did not call join() continue to process allreduce operations. This function blocks the Python
thread until all ranks join.
"""
if self.on_gpu:
hvd.join(self.local_rank)
else:
hvd.join()

def barrier(self, *args: Any, **kwargs: Any) -> None:
if distributed_available():
self.join()

def broadcast(self, obj: object, src: int = 0) -> object:
obj = hvd.broadcast_object(obj, src)
return obj

def all_gather(
self, result: Union[torch.Tensor], group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
) -> List[torch.Tensor]:
if group is not None and group != dist_group.WORLD:
raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")

if len(result.shape) == 0:
# Convert scalars to single dimension tensors
result = result.reshape(1)

# sync and gather all
self.join()
gathered = hvd.allgather(result)
gathered_result = list(gathered.split(1, dim=0))
return gathered_result

def reduce(
self,
tensor: Union[torch.Tensor, Any],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = "mean",
) -> Union[torch.Tensor, Any]:
"""Reduces a tensor from several distributed processes to one aggregated tensor.

Args:
tensor: the tensor to sync and reduce
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
Can also be a string 'sum' to calculate the sum during reduction.

Return:
reduced value, except when the input was not a tensor the output remains is unchanged
"""
if group is not None:
raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. Unset `group`.")

if reduce_op in (None, "avg", "mean"):
reduce_op = hvd.Average
elif reduce_op in ("sum", ReduceOp.SUM):
reduce_op = hvd.Sum
else:
raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

# sync all processes before reduction
self.join()
return hvd.allreduce(tensor, op=reduce_op)

def reduce_boolean_decision(self, decision: bool) -> bool:
"""Reduce the early stopping decision across all processes."""
return decision
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch

from pytorch_lightning.plugins.collective import Collective


class SingleDeviceCollective(Collective):
"""Collective interface for single device training type plugins."""

def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
"""Forces all possibly joined processes to wait for each other."""
pass

def broadcast(self, obj: object, src: int = 0) -> object:
"""Broadcasts an object to all processes."""
return obj

def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""Perform a all_gather on all processes."""
return tensor

def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
"""Reduces the given tensor (e.g. across GPUs/processes).

Args:
tensor: the tensor to sync and reduce
*args: plugin-specific positional arguments
**kwargs: plugin-specific keyword arguments
"""
return tensor

def reduce_boolean_decision(self, decision: bool) -> bool:
"""Reduce the early stopping decision across all processes."""
return decision
Loading