
Fixes access to callback_metrics in ddp_spawn #7916


Merged
43 commits merged on Jun 23, 2021 (changes shown are from 40 commits)

Commits (43)
0567de9
add spawn_callback_metrics
edgarriba Jun 10, 2021
e5fd47c
use apply_to_collection
edgarriba Jun 10, 2021
d6d6c19
Merge branch 'master' into edgar/feat/spawn_args
edgarriba Jun 10, 2021
6e9715c
make more generic spawn_extra_parameters
edgarriba Jun 10, 2021
a1b3865
generalise a bit more mp_queue parameters
edgarriba Jun 10, 2021
9b5a97c
Merge branch 'master' into edgar/feat/spawn_args
edgarriba Jun 10, 2021
c637998
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2021
97fd769
implement single and multi gpu tests
edgarriba Jun 10, 2021
2eb0836
fix typing
edgarriba Jun 10, 2021
1e60b9f
Merge branch 'master' into edgar/feat/spawn_args
edgarriba Jun 10, 2021
af4cf60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2021
3b35572
remove unused import
edgarriba Jun 10, 2021
ee719ee
Update pytorch_lightning/plugins/training_type/ddp_spawn.py
edgarriba Jun 11, 2021
284f057
Update tests/plugins/test_ddp_spawn_plugin.py
edgarriba Jun 14, 2021
37672e8
Update tests/plugins/test_ddp_spawn_plugin.py
edgarriba Jun 14, 2021
cb4ca61
Merge branch 'master' into edgar/feat/spawn_args
edgarriba Jun 14, 2021
a455a55
tensor to numpy conversion for extra parameters
edgarriba Jun 14, 2021
2ea111c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2021
ba891f2
add callback_metrics in tpu spawn plugin
edgarriba Jun 14, 2021
fa147ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2021
1c96204
Merge branch 'master' into edgar/feat/spawn_args
edgarriba Jun 14, 2021
1e7fd93
fix mypy typing issues
edgarriba Jun 14, 2021
7b8c442
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2021
78177bb
implement add_to_queue, get_from_queue
edgarriba Jun 14, 2021
2cbb0a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2021
572c2c1
fix with final api decision
edgarriba Jun 15, 2021
403b0ce
fix with final api decision
edgarriba Jun 15, 2021
8260e91
remove tests
edgarriba Jun 15, 2021
f1eba0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2021
80b8714
remove old code
edgarriba Jun 15, 2021
8805d3d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2021
c9e3ec8
Merge branch 'master' into edgar/feat/spawn_args
edgarriba Jun 16, 2021
788a425
Apply suggestions from code review
edgarriba Jun 16, 2021
fe6fd3a
remove setters, implement update dicts
edgarriba Jun 16, 2021
b016585
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 16, 2021
0cd331f
undo typing
edgarriba Jun 16, 2021
90ff74e
add small test for add/get queue
edgarriba Jun 16, 2021
5c06cf2
update ddp test for 2 gpus
edgarriba Jun 16, 2021
8ae14e6
add missing docs
edgarriba Jun 17, 2021
4aa8d9a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2021
6c219e1
add fix to changelog
edgarriba Jun 17, 2021
c441b52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2021
6a6ca3b
Update pytorch_lightning/core/lightning.py
edgarriba Jun 17, 2021
12 changes: 12 additions & 0 deletions docs/source/common/lightning_module.rst
@@ -1444,3 +1444,15 @@ on_after_batch_transfer

.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_after_batch_transfer
:noindex:

add_to_queue
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.add_to_queue
:noindex:

get_from_queue
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.get_from_queue
:noindex:
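
The two hooks documented above are meant to be overridden in user code. A minimal sketch of such an override is shown below (the `MyQueueModel` class name and `my_extra_payload` attribute are illustrative, not part of the PR); the test added in this PR exercises the same put-then-`super()` pattern.

```python
import torch
from pytorch_lightning import LightningModule


class MyQueueModel(LightningModule):
    # training_step, configure_optimizers, dataloaders, etc. are assumed
    # to be defined as usual and are omitted from this sketch.

    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
        # Runs in the spawned worker: enqueue custom data first ...
        queue.put(self.my_extra_payload)
        # ... then let the default implementation append trainer.callback_metrics.
        super().add_to_queue(queue)

    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
        # Runs in the parent process: dequeue in the same order items were enqueued.
        self.my_extra_payload = queue.get()
        super().get_from_queue(queue)
```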
34 changes: 31 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -27,6 +27,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import ScriptModule, Tensor
from torch.nn import Module
@@ -387,11 +388,11 @@ def log_dict(
on_step: Optional[bool] = None,
on_epoch: Optional[bool] = None,
reduce_fx: Union[str, Callable] = 'default', # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6
tbptt_reduce_fx: Optional = None, # noqa: Remove in 1.6
tbptt_pad_token: Optional = None, # noqa: Remove in 1.6
tbptt_reduce_fx: Optional[Any] = None, # noqa: Remove in 1.6
tbptt_pad_token: Optional[Any] = None, # noqa: Remove in 1.6
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_op: Optional = None, # noqa: Remove in 1.6
sync_dist_op: Optional[Any] = None, # noqa: Remove in 1.6
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
) -> None:
@@ -1948,3 +1949,30 @@ def model_size(self) -> float:
size_mb = os.path.getsize(tmp_name) / 1e6
os.remove(tmp_name)
return size_mb

def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue.

To avoid issues with memory sharing, we cast the data to numpy.

Args:
queue: the instance of the queue to append the data.
"""
callback_metrics: dict = apply_to_collection(
self.trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
) # send as numpy to avoid issues with memory sharing
queue.put(callback_metrics)

def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue.

To preserve consistency, we cast back the data to ``torch.Tensor``.

Args:
queue: the instance of the queue from where to get the data.
"""
# NOTE: this must be called in the right order to get the `callback_metrics`
callback_metrics: dict = queue.get()
self.trainer.callback_metrics.update(
    apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))
)

Review thread on this line:

Contributor: why do we have to update the callback metrics here?

Contributor: I'll answer for Edgar: the purpose of this PR was to provide a mechanism for users to add items to consume from callbacks in the spawn environment. Hence we update the callback metrics here; callbacks read metrics off that dictionary.

Contributor: @carmocca thanks so much!!
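
To make the numpy round-trip concrete, here is a minimal, self-contained sketch of what the default hooks do on either side of the queue. This is not the Lightning implementation itself; the metric names and values are made up for illustration.

```python
import numpy as np
import torch
import torch.multiprocessing as mp


def to_numpy(metrics: dict) -> dict:
    # Cast tensors to numpy before enqueueing to avoid memory-sharing issues.
    return {k: v.cpu().numpy() if isinstance(v, torch.Tensor) else v for k, v in metrics.items()}


def to_tensor(metrics: dict) -> dict:
    # Restore plain tensors on the receiving side for consistency with callback_metrics.
    return {k: torch.tensor(v) if isinstance(v, np.ndarray) else v for k, v in metrics.items()}


queue = mp.SimpleQueue()
fake_metrics = {"val_acc": torch.tensor(0.91), "val_loss": torch.tensor(0.23)}

queue.put(to_numpy(fake_metrics))   # what the default add_to_queue does in the child
restored = to_tensor(queue.get())   # what the default get_from_queue does in the parent
print(restored)                     # {'val_acc': tensor(0.9100), 'val_loss': tensor(0.2300)}
```

The actual hooks use `apply_to_collection`, so nested collections of tensors are handled as well, rather than the flat dict comprehensions shown here.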
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/double.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Generator, List, Tuple
from typing import Any, cast, Generator, List, Tuple

import torch
import torch.nn as nn
@@ -96,7 +96,7 @@ def connect(
incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or
`lr_schedulers`.
"""
model = model.to(dtype=torch.float64)
model = cast(LightningModule, model.to(dtype=torch.float64))
model = LightningDoublePrecisionModule(model)

return super().connect(model, optimizers, lr_schedulers)
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -214,6 +214,9 @@ def post_dispatch(self):
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
self._results = self.mp_queue.get()
# get the `callback_metrics` and set it to the trainer
# only in case the user does not override it.
self.lightning_module.get_from_queue(self.mp_queue)

# recover the weights of the processes trained in the children
self.__recover_child_process_weights(best_path, last_path)
@@ -290,6 +293,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)
self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue

def __recover_child_process_weights(self, best_path, last_path):
# transfer back the best path to the trainer
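
Because `SimpleQueue` is consumed first-in, first-out, the `get` calls in `post_dispatch` have to mirror the `put` calls made in `transfer_distrib_spawn_state_on_fit_end`, with `get_from_queue` last; this is the ordering the NOTE in `get_from_queue` refers to. A single-process sketch with placeholder values, not real Lightning objects:

```python
import torch.multiprocessing as mp

queue = mp.SimpleQueue()

# child process side (transfer_distrib_spawn_state_on_fit_end): puts, in a fixed order
queue.put("best_model_path.ckpt")   # best_model_path
queue.put("last.ckpt")              # last_path
queue.put({"fit": "results"})       # results
queue.put({"val_acc": 0.91})        # payload appended by add_to_queue

# parent process side (post_dispatch): gets must mirror that order, get_from_queue last
best_path = queue.get()
last_path = queue.get()
results = queue.get()
callback_metrics = queue.get()      # consumed by get_from_queue

assert callback_metrics == {"val_acc": 0.91}
```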
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -202,6 +202,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)
self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue

def save(self, state_dict: Dict, path: str) -> None:
    xm.save(state_dict, path)

Review thread on the add_to_queue line above:

Contributor: @edgarriba is there a reason you add_to_queue in tpu_spawn, but don't get_from_queue?

Contributor: ah, never mind, I think it's just because tpu_spawn doesn't override post_dispatch.
41 changes: 40 additions & 1 deletion tests/plugins/test_ddp_spawn_plugin.py
@@ -15,7 +15,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin
from tests.helpers.boring_model import BoringModel
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf


@@ -26,6 +26,26 @@ def on_train_start(self) -> None:
assert self.device == torch.device("cpu")


class BoringCallbackDDPSpawnModel(BoringModel):

def __init__(self, name: str, val: float):
super().__init__()
self.name = name
self.val = val

def validation_step(self, batch, batch_idx):
self.log(self.name, self.val)
return super().validation_step(batch, batch_idx)

def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
queue.put("test_val")
return super().add_to_queue(queue)

def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
self.test_val = queue.get()
return super().get_from_queue(queue)


@RunIf(skip_windows=True)
def test_ddp_cpu():
"""Tests if device is set correctely when training for DDPSpawnPlugin."""
@@ -40,3 +60,22 @@ def test_ddp_cpu():
model = BoringModelDDPCPU()

trainer.fit(model)


@RunIf(min_gpus=2)
def test_ddp_spawn_extra_parameters(tmpdir):
"""Tests if device is set correctely when training for DDPSpawnPlugin."""
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn")

assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
assert trainer.training_type_plugin.on_gpu
assert trainer.training_type_plugin.root_device == torch.device("cuda:0")

val: float = 1.0
val_name: str = "val_acc"
model = BoringCallbackDDPSpawnModel(val_name, val)
dm = BoringDataModule()

trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
assert model.test_val == "test_val"