diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0a46427fe623c..0b431b6661ab5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -257,6 +257,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
 
+- Fixed access to `callback_metrics` in ddp_spawn ([#7916](https://github.com/PyTorchLightning/pytorch-lightning/pull/7916))
+
+
 ## [1.3.5] - 2021-06-08
 
 ### Added
 
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index 824fa7a2513f2..f72e85577f3ae 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1444,3 +1444,15 @@ on_after_batch_transfer
 
 .. automethod:: pytorch_lightning.core.hooks.DataHooks.on_after_batch_transfer
     :noindex:
+
+add_to_queue
+~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.lightning.LightningModule.add_to_queue
+    :noindex:
+
+get_from_queue
+~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.lightning.LightningModule.get_from_queue
+    :noindex:
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index bc070b25e7b4e..c5c83880d6249 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -27,6 +27,7 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
+import numpy as np
 import torch
 from torch import ScriptModule, Tensor
 from torch.nn import Module
@@ -387,11 +388,11 @@ def log_dict(
         on_step: Optional[bool] = None,
         on_epoch: Optional[bool] = None,
         reduce_fx: Union[str, Callable] = 'default',  # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6
-        tbptt_reduce_fx: Optional = None,  # noqa: Remove in 1.6
-        tbptt_pad_token: Optional = None,  # noqa: Remove in 1.6
+        tbptt_reduce_fx: Optional[Any] = None,  # noqa: Remove in 1.6
+        tbptt_pad_token: Optional[Any] = None,  # noqa: Remove in 1.6
         enable_graph: bool = False,
         sync_dist: bool = False,
-        sync_dist_op: Optional = None,  # noqa: Remove in 1.6
+        sync_dist_op: Optional[Any] = None,  # noqa: Remove in 1.6
         sync_dist_group: Optional[Any] = None,
         add_dataloader_idx: bool = True,
     ) -> None:
@@ -1948,3 +1949,30 @@ def model_size(self) -> float:
         size_mb = os.path.getsize(tmp_name) / 1e6
         os.remove(tmp_name)
         return size_mb
+
+    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+        """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue.
+
+        To avoid issues with memory sharing, we cast the data to numpy.
+
+        Args:
+            queue: the instance of the queue to append the data to.
+        """
+        callback_metrics: dict = apply_to_collection(
+            self.trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
+        )  # send as numpy to avoid issues with memory sharing
+        queue.put(callback_metrics)
+
+    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+        """Retrieves the :attr:`trainer.callback_metrics` dictionary from the given queue.
+
+        To preserve consistency, we cast the data back to ``torch.Tensor``.
+
+        Args:
+            queue: the instance of the queue to get the data from.
+        """
+        # NOTE: `add_to_queue` needs to be called first
+        callback_metrics: dict = queue.get()
+        self.trainer.callback_metrics.update(
+            apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))
+        )
diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py
index a8af715d9849b..e0ecddf322250 100644
--- a/pytorch_lightning/plugins/precision/double.py
+++ b/pytorch_lightning/plugins/precision/double.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Generator, List, Tuple
+from typing import Any, cast, Generator, List, Tuple
 
 import torch
 import torch.nn as nn
@@ -96,7 +96,7 @@ def connect(
         incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or
         `lr_schedulers`.
         """
-        model = model.to(dtype=torch.float64)
+        model = cast(LightningModule, model.to(dtype=torch.float64))
         model = LightningDoublePrecisionModule(model)
 
         return super().connect(model, optimizers, lr_schedulers)
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 8d2cc217835fb..01c389353a4a0 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -214,6 +214,9 @@ def post_dispatch(self):
         best_path = self.mp_queue.get()
         last_path = self.mp_queue.get()
         self._results = self.mp_queue.get()
+        # get the `callback_metrics` from the queue and set them on the trainer;
+        # this is the default behaviour unless the user overrides `get_from_queue`
+        self.lightning_module.get_from_queue(self.mp_queue)
 
         # recover the weights of the processes trained in the children
         self.__recover_child_process_weights(best_path, last_path)
@@ -290,6 +293,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
             self.mp_queue.put(best_model_path)
             self.mp_queue.put(last_path)
             self.mp_queue.put(results)
+            self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
 
     def __recover_child_process_weights(self, best_path, last_path):
         # transfer back the best path to the trainer
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 9921fadd2cfc1..1fd9bbcb0a2cf 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -202,6 +202,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
             self.mp_queue.put(best_model_path)
             self.mp_queue.put(last_path)
             self.mp_queue.put(results)
+            self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
 
     def save(self, state_dict: Dict, path: str) -> None:
         xm.save(state_dict, path)
diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py
index 8afc30c4692ec..26a7746c41cfe 100644
--- a/tests/plugins/test_ddp_spawn_plugin.py
+++ b/tests/plugins/test_ddp_spawn_plugin.py
@@ -15,7 +15,7 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DDPSpawnPlugin
-from tests.helpers.boring_model import BoringModel
+from tests.helpers.boring_model import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
 
 
@@ -26,6 +26,26 @@ def on_train_start(self) -> None:
         assert self.device == torch.device("cpu")
 
 
+class BoringCallbackDDPSpawnModel(BoringModel):
+
+    def __init__(self, name: str, val: float):
+        super().__init__()
+        self.name = name
+        self.val = val
+
+    def validation_step(self, batch, batch_idx):
+        self.log(self.name, self.val)
+        return super().validation_step(batch, batch_idx)
+
+    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+        queue.put("test_val")
+        return super().add_to_queue(queue)
+
+    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+        self.test_val = queue.get()
+        return super().get_from_queue(queue)
+
+
 @RunIf(skip_windows=True)
 def test_ddp_cpu():
     """Tests if device is set correctely when training for DDPSpawnPlugin."""
@@ -40,3 +60,22 @@ def test_ddp_cpu():
 
     model = BoringModelDDPCPU()
     trainer.fit(model)
+
+
+@RunIf(min_gpus=2)
+def test_ddp_spawn_extra_parameters(tmpdir):
+    """Tests that ``callback_metrics`` and extra queue data are passed back by ``DDPSpawnPlugin``."""
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp_spawn")
+
+    assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
+    assert trainer.training_type_plugin.on_gpu
+    assert trainer.training_type_plugin.root_device == torch.device("cuda:0")
+
+    val: float = 1.0
+    val_name: str = "val_acc"
+    model = BoringCallbackDDPSpawnModel(val_name, val)
+    dm = BoringDataModule()
+
+    trainer.fit(model, datamodule=dm)
+    assert trainer.callback_metrics[val_name] == torch.tensor(val)
+    assert model.test_val == "test_val"
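For context, here is a minimal sketch (not part of the patch) of how the two new hooks are meant to be overridden from user code, mirroring `BoringCallbackDDPSpawnModel` in the test above. The class name `LitModel` and the attribute `self.extra` are illustrative only; the hook signatures are the ones added to `LightningModule` in this patch, and items must be read back in the same order they were put on the queue.

import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    # illustrative sketch: training/validation methods omitted, only the queue hooks are shown

    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
        # runs in the spawned worker: enqueue custom state first, then let the
        # base class append `trainer.callback_metrics` (cast to numpy) behind it
        queue.put({"extra": 123})
        super().add_to_queue(queue)

    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
        # runs in the main process after fitting: dequeue in the same order, then
        # let the base class restore `trainer.callback_metrics` as tensors
        self.extra = queue.get()
        super().get_from_queue(queue)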