Move add_to_queue/get_from_queue to DDPSpawnPlugin #9118


Merged: 30 commits, Sep 10, 2021 (changes from 29 commits)

Commits
db65aa1  Move add_to_queue/get_from_queue to ddp_spawn.py (daniellepintz, Aug 21, 2021)
4ca811c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 25, 2021)
785be9e  Update changelog and add todos (daniellepintz, Aug 25, 2021)
e9ead82  Merge branch 'ddp_spawn' of github.com:daniellepintz/pytorch-lightnin… (daniellepintz, Aug 25, 2021)
edc0ad6  wip (daniellepintz, Sep 1, 2021)
c782bd6  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz, Sep 1, 2021)
ccf44e4  address comments (daniellepintz, Sep 1, 2021)
ced51e1  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz, Sep 2, 2021)
11e9292  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 2, 2021)
238587a  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz, Sep 2, 2021)
c7be9ea  Merge branch 'ddp_spawn' of github.com:daniellepintz/pytorch-lightnin… (daniellepintz, Sep 2, 2021)
7edf7f8  update import path (daniellepintz, Sep 2, 2021)
9117350  update import path (daniellepintz, Sep 2, 2021)
a177c49  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 2, 2021)
12cc5d2  resolve circular reference (tchaton, Sep 2, 2021)
705bba2  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz, Sep 2, 2021)
3f7b4f1  Merge branch 'ddp_spawn' of github.com:daniellepintz/pytorch-lightnin… (daniellepintz, Sep 2, 2021)
e2a7a51  address comment (daniellepintz, Sep 3, 2021)
7f48f26  Update CHANGELOG.md (daniellepintz, Sep 3, 2021)
4d591e8  Update CHANGELOG.md (daniellepintz, Sep 3, 2021)
a3936e6  Apply suggestions from code review (daniellepintz, Sep 6, 2021)
546a28b  address comments (daniellepintz, Sep 6, 2021)
e5d2289  Apply suggestions from code review (Borda, Sep 6, 2021)
e561650  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz, Sep 6, 2021)
dfed884  Merge branch 'ddp_spawn' of github.com:daniellepintz/pytorch-lightnin… (daniellepintz, Sep 6, 2021)
446fafd  update test (daniellepintz, Sep 6, 2021)
113645d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 6, 2021)
a1a439d  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz, Sep 10, 2021)
ddd4e57  Merge branch 'ddp_spawn' of github.com:daniellepintz/pytorch-lightnin… (daniellepintz, Sep 10, 2021)
e73b419  RunIf(skip_windows=True) (daniellepintz, Sep 10, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -179,6 +179,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))


- Deprecated `add_to_queue`, `get_from_queue` from `LightningModule` in favor of corresponding methods in the `DDPSpawnPlugin` ([#9118](https://github.com/PyTorchLightning/pytorch-lightning/pull/9118))


- Deprecated `LightningModule.get_progress_bar_dict` and `Trainer.progress_bar_dict` in favor of `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and `ProgressBarBase.get_metrics` ([#8985](https://github.com/PyTorchLightning/pytorch-lightning/pull/8985))


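For orientation, here is a minimal sketch of the migration this deprecation implies: the queue hooks move off the `LightningModule` and onto the plugin, and they now receive the trainer explicitly. The `MyDDPSpawnPlugin` subclass and the `"my_value"` payload below are illustrative, not part of this PR:

```python
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin


class MyDDPSpawnPlugin(DDPSpawnPlugin):
    def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
        queue.put("my_value")  # enqueue custom data alongside the callback metrics
        super().add_to_queue(trainer, queue)

    def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
        self.my_value = queue.get()  # dequeue in the same order the data was enqueued
        super().get_from_queue(trainer, queue)


# pass the plugin to the Trainer instead of overriding the LightningModule hooks
trainer = Trainer(num_processes=2, accelerator="ddp_cpu", plugins=[MyDDPSpawnPlugin()])
```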
5 changes: 2 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -21,9 +21,8 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
- from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
- from pytorch_lightning.plugins.training_type import TrainingTypePlugin
+ from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
@@ -116,7 +115,7 @@ def dispatch(self, trainer: "pl.Trainer") -> None:

def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction starts."""
- self.training_type_plugin.post_dispatch()
+ self.training_type_plugin.post_dispatch(trainer)
self.precision_plugin.post_dispatch()

@property
23 changes: 13 additions & 10 deletions pytorch_lightning/core/lightning.py
@@ -24,13 +24,13 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

- import numpy as np
import torch
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric

+ import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import base as progress_base
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
@@ -1905,24 +1905,27 @@ def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:

Args:
queue: the instance of the queue to append the data.

+ .. deprecated:: v1.5
+     This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.add_to_queue`
+     and will be removed in v1.7.
"""
- callback_metrics: dict = apply_to_collection(
-     self.trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
- )  # send as numpy to avoid issues with memory sharing
- queue.put(callback_metrics)
+ if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
+     self.trainer.training_type_plugin.add_to_queue(self.trainer, queue)

def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
we cast back the data to ``torch.Tensor``.

Args:
queue: the instance of the queue from where to get the data.

+ .. deprecated:: v1.5
+     This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.get_from_queue`
+     and will be removed in v1.7.
"""
- # NOTE: `add_to_queue` needs to be called before
- callback_metrics: dict = queue.get()
- self.trainer.callback_metrics.update(
-     apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))
- )
+ if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
+     self.trainer.training_type_plugin.get_from_queue(self.trainer, queue)

@contextmanager
def _prevent_trainer_and_dataloaders_deepcopy(self) -> None:
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -29,6 +29,7 @@
import torch.distributed
from torch.nn.parallel.distributed import DistributedDataParallel

+ import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
@@ -385,7 +386,7 @@ def pre_dispatch(self):
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

- def post_dispatch(self) -> None:
+ def post_dispatch(self, trainer: "pl.Trainer") -> None:
self.cluster_environment.teardown()

def barrier(self, *args, **kwargs) -> None:
41 changes: 38 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -17,6 +17,7 @@
from multiprocessing.queues import SimpleQueue
from typing import Any, Dict, List, Optional, Union

+ import numpy as np
import torch
import torch.distributed
import torch.multiprocessing as mp
@@ -36,6 +37,7 @@
rank_zero_deprecation,
rank_zero_warn,
)
+ from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
@@ -45,6 +47,7 @@
ReduceOp,
sync_ddp_if_available,
)
+ from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -215,14 +218,18 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()

- def post_dispatch(self):
+ def post_dispatch(self, trainer: "pl.Trainer"):
# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
self._results = self.mp_queue.get()
# get the `callback_metrics` and set it to the trainer
# only in case the user does not override it.
- self.lightning_module.get_from_queue(self.mp_queue)
+ # TODO: Remove the if in v1.7
+ if is_overridden("get_from_queue", self.lightning_module):
+     self.lightning_module.get_from_queue(self.mp_queue)
+ else:
+     self.get_from_queue(trainer, self.mp_queue)

# recover the weights of the processes trained in the children
self.__recover_child_process_weights(best_path, last_path)
@@ -288,7 +295,12 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)
- self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
+ # adds the `callback_metrics` to the queue
+ # TODO: Remove the if in v1.7
+ if is_overridden("add_to_queue", self.lightning_module):
+     self.lightning_module.add_to_queue(self.mp_queue)
+ else:
+     self.add_to_queue(trainer, self.mp_queue)

def __recover_child_process_weights(self, best_path, last_path):
# transfer back the best path to the trainer
@@ -362,6 +374,29 @@ def post_training_step(self):
if not self.lightning_module.automatic_optimization:
self.model.require_backward_grad_sync = True

def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
sharing, we cast the data to numpy.

Args:
queue: the instance of the queue to append the data.
"""
callback_metrics: dict = apply_to_collection(
trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
) # send as numpy to avoid issues with memory sharing
queue.put(callback_metrics)

def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
we cast back the data to ``torch.Tensor``.

Args:
queue: the instance of the queue from where to get the data.
"""
# NOTE: `add_to_queue` needs to be called before
callback_metrics: dict = queue.get()
trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))

@classmethod
def register_plugins(cls, plugin_registry: Dict) -> None:
plugin_registry.register(
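The plugin methods above keep the numpy round-trip from the old `LightningModule` hooks: per the in-code comments, tensors are converted to numpy before being queued to avoid memory-sharing issues between spawned processes, and cast back to `torch.Tensor` on the parent side. A standalone sketch of the same pattern, independent of Lightning (the `child` worker and the metric value are illustrative):

```python
import torch
import torch.multiprocessing as mp


def child(rank, queue):
    # send as numpy to avoid issues with memory sharing, as add_to_queue does
    metrics = {"val_acc": torch.tensor(0.9)}
    queue.put({k: v.cpu().numpy() for k, v in metrics.items()})


if __name__ == "__main__":
    queue = mp.get_context("spawn").SimpleQueue()
    mp.spawn(child, args=(queue,), nprocs=1)  # joins the worker by default
    # cast back to torch.Tensor on the parent side, as get_from_queue does
    print({k: torch.tensor(v) for k, v in queue.get().items()})
```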
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -351,5 +351,5 @@ def pre_dispatch(self) -> None:
def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something at trainer run_stage starts."""

- def post_dispatch(self) -> None:
+ def post_dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""
19 changes: 19 additions & 0 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -43,6 +43,7 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
elif self.trainer.state.fn == TrainerFn.PREDICTING:
self.__verify_predict_loop_configuration(model)
self.__verify_dp_batch_transfer_support(model)
+ self._check_add_get_queue(model)
# TODO(@daniellepintz): Delete _check_progress_bar in v1.7
self._check_progress_bar(model)
# TODO: Delete _check_on_keyboard_interrupt in v1.7
@@ -219,6 +220,24 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod
"is incompatible with `truncated_bptt_steps > 0`."
)

def _check_add_get_queue(self, model: "pl.LightningModule") -> None:
r"""
Checks if add_to_queue or get_from_queue is overridden and sends a deprecation warning.

Args:
model: The lightning module
"""
if is_overridden("add_to_queue", model):
rank_zero_deprecation(
"The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7 in "
"favor of `DDPSpawnPlugin.add_to_queue`"
)
if is_overridden("get_from_queue", model):
rank_zero_deprecation(
"The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7 in "
"favor of `DDPSpawnPlugin.get_from_queue`"
)

def _check_on_keyboard_interrupt(self) -> None:
"""Checks if on_keyboard_interrupt is overriden and sends a deprecation warning."""
for callback in self.trainer.callbacks:
26 changes: 26 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
@@ -15,12 +15,14 @@
from unittest import mock

import pytest
+ import torch

from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.loggers import TestTubeLogger
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
+ from tests.helpers.runif import RunIf


def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir):
@@ -192,3 +194,27 @@ def on_keyboard_interrupt(self, trainer, pl_module):
def test_v1_7_0_process_position_trainer_constructor(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(process_position=5\)` is deprecated in v1.5"):
_ = Trainer(process_position=5)


class BoringCallbackDDPSpawnModel(BoringModel):
def __init__(self):
super().__init__()

def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
queue.put("test_val")
return super().add_to_queue(queue)

def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
self.test_val = queue.get()
return super().get_from_queue(queue)


def test_v1_7_0_deprecate_add_get_queue(tmpdir):
model = BoringCallbackDDPSpawnModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, num_processes=2, accelerator="ddp_cpu")

with pytest.deprecated_call(match=r"`LightningModule.add_to_queue` method was deprecated in v1.5"):
trainer.fit(model)

with pytest.deprecated_call(match=r"`LightningModule.get_from_queue` method was deprecated in v1.5"):
trainer.fit(model)
33 changes: 30 additions & 3 deletions tests/plugins/test_ddp_spawn_plugin.py
@@ -48,7 +48,7 @@ def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:

@RunIf(skip_windows=True)
def test_ddp_cpu():
"""Tests if device is set correctely when training for DDPSpawnPlugin."""
"""Tests if device is set correctly when training for DDPSpawnPlugin."""
trainer = Trainer(num_processes=2, fast_dev_run=True)
# assert training type plugin attributes for device setting

@@ -64,7 +64,8 @@ def test_ddp_cpu():

@RunIf(min_gpus=2)
def test_ddp_spawn_extra_parameters(tmpdir):
"""Tests if device is set correctely when training for DDPSpawnPlugin."""
"""Tests if device is set correctly when training for DDPSpawnPlugin and tests add_to_queue/get_from_queue with
Lightning Module (deprecated way)."""
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn")

assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
@@ -75,12 +76,38 @@ def test_ddp_spawn_extra_parameters(tmpdir):
val_name: str = "val_acc"
model = BoringCallbackDDPSpawnModel(val_name, val)
dm = BoringDataModule()

trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
assert model.test_val == "test_val"


class TestDDPSpawnPlugin(DDPSpawnPlugin):
def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
queue.put("new_test_val")
return super().add_to_queue(trainer, queue)

def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None:
self.new_test_val = queue.get()
return super().get_from_queue(trainer, queue)


def test_ddp_spawn_add_get_queue(tmpdir):
"""Tests add_to_queue/get_from_queue with DDPSpawnPlugin."""

ddp_spawn_plugin = TestDDPSpawnPlugin()
trainer = Trainer(
default_root_dir=tmpdir, fast_dev_run=True, num_processes=2, accelerator="ddp_cpu", plugins=[ddp_spawn_plugin]
)

val: float = 1.0
val_name: str = "val_acc"
model = BoringCallbackDDPSpawnModel(val_name, val)
dm = BoringDataModule()
trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
assert ddp_spawn_plugin.new_test_val == "new_test_val"


class BoringModelDDP(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as DistributedDataParallel during training stage."""
Expand Down