From a2337e8c8bb0629dccb94ab60dd7f6b53dd1c5b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Nov 2021 13:39:12 +0100 Subject: [PATCH 01/10] add fix and test --- pytorch_lightning/lite/wrappers.py | 3 ++- tests/lite/test_wrappers.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 615f461055204..0432c4b9c2274 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -24,6 +24,7 @@ from torch.utils.data import DataLoader from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin from pytorch_lightning.plugins import PrecisionPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device @@ -64,7 +65,7 @@ def step(self, closure: Optional[Callable] = None) -> None: ) -class _LiteModule(nn.Module): +class _LiteModule(DeviceDtypeModuleMixin, nn.Module): def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None: """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast automatically for the forward pass. 
diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 4993a10c8dbc2..65789739ae5ad 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -15,8 +15,10 @@ import pytest import torch +import torch.nn as nn from torch.utils.data.dataloader import DataLoader +from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin from pytorch_lightning.lite import LightningLite from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from tests.helpers.runif import RunIf @@ -65,6 +67,18 @@ def check_autocast(forward_input): assert out.dtype == input_type or out.dtype == torch.get_default_dtype() +@RunIf(min_gpus=1) +@pytest.mark.parametrize("device", ["cpu", torch.device("cuda", 0)]) +def test_lite_module_device_propagation(device): + class DeviceModule(DeviceDtypeModuleMixin): + pass + + module = DeviceModule() + lite_module = _LiteModule(module, Mock()) + lite_module.to(device) + assert lite_module.device is lite_module.module.device is module.device is device + + def test_lite_dataloader_iterator(): """Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic device placement).""" From 8a91b5813f0aeb5e39fd1dc102714dfc572a76f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Nov 2021 13:40:55 +0100 Subject: [PATCH 02/10] eq --- tests/lite/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 65789739ae5ad..3c051186612eb 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -76,7 +76,7 @@ class DeviceModule(DeviceDtypeModuleMixin): module = DeviceModule() lite_module = _LiteModule(module, Mock()) lite_module.to(device) - assert lite_module.device is lite_module.module.device is module.device is device + assert lite_module.device == lite_module.module.device == module.device == device def 
test_lite_dataloader_iterator(): From bbabb4466d07ee6422383e8d20f2f6bf21c654be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Nov 2021 13:41:54 +0100 Subject: [PATCH 03/10] device --- tests/lite/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 3c051186612eb..73ab69647234c 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -68,7 +68,7 @@ def check_autocast(forward_input): @RunIf(min_gpus=1) -@pytest.mark.parametrize("device", ["cpu", torch.device("cuda", 0)]) +@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda", 0)]) def test_lite_module_device_propagation(device): class DeviceModule(DeviceDtypeModuleMixin): pass From 917d8637f1770a3a5fee6ec17a066a86b3aea415 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Nov 2021 13:42:25 +0100 Subject: [PATCH 04/10] verify no fix --- pytorch_lightning/lite/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 0432c4b9c2274..bfa10c1b65c3c 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -65,7 +65,7 @@ def step(self, closure: Optional[Callable] = None) -> None: ) -class _LiteModule(DeviceDtypeModuleMixin, nn.Module): +class _LiteModule(nn.Module): def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None: """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast automatically for the forward pass. 
From 7bdc2f114aba4822ca0723a8422c96c998659f2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Nov 2021 13:44:13 +0100 Subject: [PATCH 05/10] update assertion --- tests/lite/test_wrappers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 73ab69647234c..2c9601dfcc8bb 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -73,10 +73,11 @@ def test_lite_module_device_propagation(device): class DeviceModule(DeviceDtypeModuleMixin): pass - module = DeviceModule() - lite_module = _LiteModule(module, Mock()) + device_module = DeviceModule() + lite_module = _LiteModule(device_module, Mock()) lite_module.to(device) - assert lite_module.device == lite_module.module.device == module.device == device + assert device_module.device == device + assert lite_module.device == device def test_lite_dataloader_iterator(): From 1deb69a7978ed1aa6706a862f36775d2734a9b52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Nov 2021 13:45:01 +0100 Subject: [PATCH 06/10] reapply fix --- pytorch_lightning/lite/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index bfa10c1b65c3c..ff95e89d1d2cf 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -65,7 +65,7 @@ def step(self, closure: Optional[Callable] = None) -> None: ) -class _LiteModule(nn.Module): +class _LiteModule(DeviceDtypeModuleMixin): def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None: """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast automatically for the forward pass. 
From 59ac447872cfba1c737c12a790da8e8a8ecade07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Nov 2021 13:52:43 +0100 Subject: [PATCH 07/10] expand test for dtype --- tests/lite/test_wrappers.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 2c9601dfcc8bb..ea7d825c5a5bd 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -67,9 +67,13 @@ def check_autocast(forward_input): assert out.dtype == input_type or out.dtype == torch.get_default_dtype() -@RunIf(min_gpus=1) -@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda", 0)]) -def test_lite_module_device_propagation(device): +@pytest.mark.parametrize( + "device", [torch.device("cpu"), pytest.param(torch.device("cuda", 0), marks=RunIf(min_gpus=1))] +) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_lite_module_device_dtype_propagation(device, dtype): + """Test that the LiteModule propagates device and dtype properties to its submodules (e.g. 
torchmetrics).""" + class DeviceModule(DeviceDtypeModuleMixin): pass @@ -79,6 +83,10 @@ class DeviceModule(DeviceDtypeModuleMixin): assert device_module.device == device assert lite_module.device == device + lite_module.to(dtype) + assert device_module.dtype == dtype + assert lite_module.dtype == dtype + def test_lite_dataloader_iterator(): """Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic From 927355b0b7a4585b9cdcd21936b89d5a3df2e140 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Nov 2021 13:55:33 +0100 Subject: [PATCH 08/10] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 25afcc58e54f1..8fec8aa0363c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -163,7 +163,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486)) -- +- Fixed propagation of device and dtype information to submodules of LightningLite when they inherit from `DeviceDtypeModuleMixin` ([#](https://github.com/PyTorchLightning/pytorch-lightning/issues/)) - From 52050111ba5fc122868110a5fe913b72a6afbfe7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Nov 2021 14:11:02 +0100 Subject: [PATCH 09/10] add number --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fec8aa0363c8..881daf3308536 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -163,7 +163,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486)) -- Fixed propagation of device and dtype information to submodules of LightningLite when they inherit from `DeviceDtypeModuleMixin` ([#](https://github.com/PyTorchLightning/pytorch-lightning/issues/)) +- Fixed propagation of device and dtype information to submodules of LightningLite when they inherit from `DeviceDtypeModuleMixin` ([#10559](https://github.com/PyTorchLightning/pytorch-lightning/issues/10559)) - From b021310655eec98d13e53a215f694c71b8fe619b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Nov 2021 14:27:20 +0100 Subject: [PATCH 10/10] remove unused import --- tests/lite/test_wrappers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index ea7d825c5a5bd..c271d3b3163ed 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -15,7 +15,6 @@ import pytest import torch -import torch.nn as nn from torch.utils.data.dataloader import DataLoader from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin