Skip to content

Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly #9691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 25, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated passing `progress_bar_refresh_rate` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `refresh_rate` directly to the list of callbacks ([#9616](https://github.com/PyTorchLightning/pytorch-lightning/pull/9616))


- Deprecate `LightningDistributed` and move the broadcast logic to `DDPPlugin` and `DDPSpawnPlugin` directly ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
Expand Down Expand Up @@ -395,6 +398,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `trainer.accumulate_grad_batches` to be an int on init. Default value for it is now `None` inside Trainer ([#9652](https://github.com/PyTorchLightning/pytorch-lightning/pull/9652))


- Fixed `broadcast` in `DDPPlugin` and ``DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))


## [1.4.8] - 2021-09-22

- Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
Expand Down
11 changes: 11 additions & 0 deletions pytorch_lightning/distributed/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,22 @@
from typing import Any

from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.distributed import group as _group


class LightningDistributed:
"""
.. deprecated:: v1.5
This class is deprecated in v1.5 and will be removed in v1.7.
The broadcast logic will be moved to the :class:`DDPPlugin` and :class`DDPSpawnPlugin` classes.
"""

def __init__(self, rank=None, device=None):
rank_zero_deprecation(
"LightningDistributed is deprecated in v1.5 and will be removed in v1.7."
"Broadcast logic is implemented directly in the :class:`TrainingTypePlugin` implementations."
)
self.rank = rank
self.device = device

Expand Down
21 changes: 9 additions & 12 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
Expand All @@ -48,13 +48,9 @@
rank_zero_deprecation,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import (
distributed_available,
init_ddp_connection,
rank_zero_only,
ReduceOp,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
Expand Down Expand Up @@ -116,7 +112,6 @@ def __init__(
" Notice that it will be overriden by the trainer setting."
)
self._sync_batchnorm = sync_batchnorm or False
self.dist = LightningDistributed()
self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
self._ddp_kwargs = kwargs
self._task_idx = None
Expand Down Expand Up @@ -270,8 +265,6 @@ def setup_distributed(self):
init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

def _check_can_spawn_children(self):
if self.local_rank != 0:
Expand Down Expand Up @@ -403,7 +396,11 @@ def barrier(self, *args, **kwargs) -> None:
torch.distributed.barrier()

def broadcast(self, obj: object, src: int = 0) -> object:
return self.dist.broadcast(obj)
obj = [obj]
if self.global_rank != src:
obj = [None] * len(obj)
broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]

def pre_backward(self, closure_loss: torch.Tensor) -> None:
"""Run before precision plugin executes backward."""
Expand Down
23 changes: 9 additions & 14 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
Expand All @@ -40,13 +40,9 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
distributed_available,
init_ddp_connection,
rank_zero_only,
ReduceOp,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
Expand Down Expand Up @@ -93,7 +89,6 @@ def __init__(
)
self._sync_batchnorm = sync_batchnorm or False
self._ddp_kwargs = kwargs
self.dist = LightningDistributed()
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
self.mp_queue = None
self._ddp_comm_state = ddp_comm_state
Expand Down Expand Up @@ -193,10 +188,6 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

# move the model to the correct device
self.model_to_device()

Expand Down Expand Up @@ -324,7 +315,11 @@ def barrier(self, *args, **kwargs) -> None:
def broadcast(self, obj: object, src: int = 0) -> object:
if not distributed_available():
return obj
return self.dist.broadcast(obj)
obj = [obj]
if self.global_rank != src:
obj = [None] * len(obj)
broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]

def model_to_device(self):
if self.root_device.type == "cuda":
Expand Down
3 changes: 0 additions & 3 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,6 @@ def setup_distributed(self):

self._init_deepspeed_distributed()

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device
if not self._config_initialized:
self._format_config()
self._config_initialized = True
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ omit =
pytorch_lightning/cluster_environments/*.py
pytorch_lightning/utilities/distributed.py
pytorch_lightning/tuner/auto_gpu_select.py
pytorch_lightning/distributed/dist.py


[flake8]
Expand Down
7 changes: 7 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,10 @@ def test_v1_7_0_lightning_logger_base_close(tmpdir):
):
logger = LoggerCollection([logger])
logger.close()


def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
with pytest.deprecated_call(match="LightningDistributed is deprecated in v1.5 and will be removed in v1.7."):
from pytorch_lightning.distributed.dist import LightningDistributed

_ = LightningDistributed()