Skip to content

Fix propagation of device and dtype properties in Lite modules #10559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))


-
- Fixed propagation of device and dtype information to submodules of LightningLite when they inherit from `DeviceDtypeModuleMixin` ([#10559](https://github.com/PyTorchLightning/pytorch-lightning/issues/10559))


-
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/lite/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from torch.utils.data import DataLoader

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device

Expand Down Expand Up @@ -64,7 +65,7 @@ def step(self, closure: Optional[Callable] = None) -> None:
)


class _LiteModule(nn.Module):
class _LiteModule(DeviceDtypeModuleMixin):
def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None:
"""The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
automatically for the forward pass.
Expand Down
22 changes: 22 additions & 0 deletions tests/lite/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from tests.helpers.runif import RunIf
Expand Down Expand Up @@ -65,6 +66,27 @@ def check_autocast(forward_input):
assert out.dtype == input_type or out.dtype == torch.get_default_dtype()


@pytest.mark.parametrize(
"device", [torch.device("cpu"), pytest.param(torch.device("cuda", 0), marks=RunIf(min_gpus=1))]
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_lite_module_device_dtype_propagation(device, dtype):
"""Test that the LiteModule propagates device and dtype properties to its submodules (e.g. torchmetrics)."""

class DeviceModule(DeviceDtypeModuleMixin):
pass

device_module = DeviceModule()
lite_module = _LiteModule(device_module, Mock())
lite_module.to(device)
assert device_module.device == device
assert lite_module.device == device

lite_module.to(dtype)
assert device_module.dtype == dtype
assert lite_module.dtype == dtype


def test_lite_dataloader_iterator():
"""Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic
device placement)."""
Expand Down