
Commit 7ab481d

awaelchli authored and Raalsky committed
Fix propagation of device and dtype properties in Lite modules (Lightning-AI#10559)
1 parent 9a0237a commit 7ab481d
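
For context before the diffs: `_LiteModule` previously subclassed `nn.Module` directly, so calling `.to(...)` on a Lite-wrapped model moved its parameters and buffers but left the `device`/`dtype` properties of `DeviceDtypeModuleMixin` submodules (e.g. torchmetrics-style metrics) stale. A minimal sketch of the behavior this commit enables, assuming a Lightning version that contains the fix; the `MyMetric` class is a hypothetical stand-in mirroring the test added below:

```python
import torch
from unittest.mock import Mock

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.lite.wrappers import _LiteModule


class MyMetric(DeviceDtypeModuleMixin):
    """Hypothetical stand-in for any device/dtype-tracking submodule."""


metric = MyMetric()
# The precision plugin is mocked out here, exactly as in the new test.
lite_module = _LiteModule(metric, Mock())

lite_module.to(torch.float16)
# With this fix, the dtype (and likewise device) information reaches the
# wrapped submodule; previously only parameters/buffers were converted.
assert metric.dtype == lite_module.dtype == torch.float16
```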

File tree

3 files changed: +26 −3 lines

CHANGELOG.md
pytorch_lightning/lite/wrappers.py
tests/lite/test_wrappers.py

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
@@ -163,10 +163,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
 
 
-- Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))
+- Fixed propagation of device and dtype information to submodules of LightningLite when they inherit from `DeviceDtypeModuleMixin` ([#10559](https://github.com/PyTorchLightning/pytorch-lightning/issues/10559))
 
 
--
+- Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))
 
 ## [1.5.1] - 2021-11-09
 

pytorch_lightning/lite/wrappers.py

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@
 from torch.utils.data import DataLoader
 
 from pytorch_lightning.accelerators import Accelerator
+from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.plugins import PrecisionPlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 
@@ -64,7 +65,7 @@ def step(self, closure: Optional[Callable] = None) -> None:
         )
 
 
-class _LiteModule(nn.Module):
+class _LiteModule(DeviceDtypeModuleMixin):
     def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None:
         """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
         automatically for the forward pass.
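
The one-line base-class change above is the whole fix: `DeviceDtypeModuleMixin` intercepts the move/cast entry points and records the target device and dtype on every mixin submodule. The actual implementation lives in `pytorch_lightning.core.mixins` and covers more entry points (such as `.cuda()` and `.half()`); below is a simplified, hypothetical sketch of the mechanism that only handles `.to(...)`, not the library's code:

```python
import torch
from torch import nn


class TrackingMixin(nn.Module):
    """Simplified sketch: track the last device/dtype passed to `.to(...)`
    and propagate it to every TrackingMixin submodule via `nn.Module.apply`."""

    _device: torch.device = torch.device("cpu")
    _dtype: torch.dtype = torch.get_default_dtype()

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    def to(self, *args, **kwargs):
        # _parse_to is the (private) helper nn.Module.to itself uses to
        # normalize its overloaded arguments into (device, dtype, ...).
        device, dtype, *_ = torch._C._nn._parse_to(*args, **kwargs)

        def update(module: nn.Module) -> None:
            if isinstance(module, TrackingMixin):
                if device is not None:
                    module._device = device
                if dtype is not None:
                    module._dtype = dtype

        # `apply` visits self and all submodules, which is why a wrapper
        # inheriting such a mixin keeps wrapped mixin modules in sync.
        self.apply(update)
        return super().to(*args, **kwargs)


# Usage: a wrapper holding a mixin submodule stays in sync with it.
class Wrapper(TrackingMixin):
    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module


inner = TrackingMixin()
outer = Wrapper(inner)
outer.to(torch.float64)
assert inner.dtype == outer.dtype == torch.float64
```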

tests/lite/test_wrappers.py

Lines changed: 22 additions & 0 deletions
@@ -17,6 +17,7 @@
 import torch
 from torch.utils.data.dataloader import DataLoader
 
+from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.lite import LightningLite
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 from tests.helpers.runif import RunIf
@@ -65,6 +66,27 @@ def check_autocast(forward_input):
     assert out.dtype == input_type or out.dtype == torch.get_default_dtype()
 
 
+@pytest.mark.parametrize(
+    "device", [torch.device("cpu"), pytest.param(torch.device("cuda", 0), marks=RunIf(min_gpus=1))]
+)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_lite_module_device_dtype_propagation(device, dtype):
+    """Test that the LiteModule propagates device and dtype properties to its submodules (e.g. torchmetrics)."""
+
+    class DeviceModule(DeviceDtypeModuleMixin):
+        pass
+
+    device_module = DeviceModule()
+    lite_module = _LiteModule(device_module, Mock())
+    lite_module.to(device)
+    assert device_module.device == device
+    assert lite_module.device == device
+
+    lite_module.to(dtype)
+    assert device_module.dtype == dtype
+    assert lite_module.dtype == dtype
+
+
 def test_lite_dataloader_iterator():
     """Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic
     device placement)."""
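
A note on the test hunk: only the `DeviceDtypeModuleMixin` import is added, since `pytest` and `Mock` are evidently already imported at the top of this module. The CUDA case is gated by the `RunIf(min_gpus=1)` marker, so on a CPU-only machine only the `torch.device("cpu")` parametrizations run; the new test can be selected locally with standard pytest filtering, e.g. `pytest tests/lite/test_wrappers.py -k test_lite_module_device_dtype_propagation`.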
