Skip to content

Commit d7404c7

Browse files
awaelchlicarmoccaBordajustusschock
authored
Integration tests for Precision in Lite (#14815)
Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Justus Schock <[email protected]>
1 parent 8c01e82 commit d7404c7

File tree

8 files changed

+255
-7
lines changed

8 files changed

+255
-7
lines changed

src/lightning_lite/plugins/precision/deepspeed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class DeepSpeedPrecision(Precision):
3030
"""Precision plugin for DeepSpeed integration.
3131
3232
Args:
33-
precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
33+
precision: Full precision (32), half precision (16) or bfloat16 precision (bf16).
3434
amp_type: The mixed precision backend to use ("native" or "apex").
3535
amp_level: The optimization level to use (O1, O2, etc...). By default it will be set to "O2"
3636
if ``amp_type`` is set to "apex".

tests/tests_lite/helpers/models.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Any, Iterator
2+
3+
import torch
4+
import torch.nn as nn
5+
from torch import Tensor
6+
from torch.nn import Module
7+
from torch.optim import Optimizer
8+
from torch.utils.data import DataLoader, Dataset, IterableDataset
9+
10+
from lightning_lite import LightningLite
11+
12+
13+
class RandomDataset(Dataset):
    """Map-style dataset holding ``length`` random samples of ``size`` features each.

    Args:
        size: Number of features in every sample.
        length: Number of samples in the dataset.
    """

    def __init__(self, size: int, length: int) -> None:
        self.len = length
        # Samples are drawn once up front, so indexing the same position twice returns the same tensor.
        self.data = torch.randn(length, size)

    def __getitem__(self, index: int) -> Tensor:
        """Return the sample stored at position ``index``."""
        return self.data[index]

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.len
23+
24+
25+
class RandomIterableDataset(IterableDataset):
    """Iterable-style dataset yielding ``count`` freshly sampled random vectors of ``size`` features.

    Args:
        size: Number of features in every yielded sample.
        count: Number of samples produced per iteration pass.
    """

    def __init__(self, size: int, count: int) -> None:
        self.count = count
        self.size = size

    def __iter__(self) -> Iterator[Tensor]:
        """Yield ``count`` samples; each one is drawn at iteration time, not cached."""
        for _ in range(self.count):
            yield torch.randn(self.size)
33+
34+
35+
class BoringLite(LightningLite):
    """Minimal ``LightningLite`` that runs a single optimization step on a linear model with random data.

    ``get_model``, ``get_dataloader``, ``step`` and the two ``after_*`` hooks can be overridden to inject
    assertions at specific points of that step.
    """

    def get_model(self) -> Module:
        """Return the model to train: a single ``Linear(32, 2)`` layer."""
        return nn.Linear(32, 2)

    def get_dataloader(self) -> DataLoader:
        """Return a dataloader over 64 random samples with 32 features each."""
        return DataLoader(RandomDataset(32, 64))

    def step(self, model: Module, batch: Any) -> Tensor:
        """Run the forward pass and return the MSE loss against an all-ones target."""
        output = model(batch)
        loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
        return loss

    def after_backward(self, model: Module) -> None:
        """Hook invoked right after ``self.backward(loss)``; does nothing by default."""

    def after_optimizer_step(self, model: Module, optimizer: Optimizer) -> None:
        """Hook invoked right after ``optimizer.step()``; does nothing by default."""

    def run(self) -> None:
        """Set up model, optimizer and dataloader, then execute exactly one training step."""
        model = self.get_model()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        dataloader = self.get_dataloader()

        model, optimizer = self.setup(model, optimizer)
        dataloader = self.setup_dataloaders(dataloader)

        # Only the first batch is consumed: one step is enough for the hooks to run their checks.
        data_iter = iter(dataloader)
        batch = next(data_iter)
        loss = self.step(model, batch)
        self.backward(loss)
        self.after_backward(model)
        optimizer.step()
        self.after_optimizer_step(model, optimizer)
        optimizer.zero_grad()
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Integration tests for double-precision training."""
15+
16+
import torch
17+
import torch.nn as nn
18+
from tests_lite.helpers.models import BoringLite
19+
20+
21+
class BoringDoubleModule(nn.Module):
    """Linear module whose forward pass asserts that it runs under double precision."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # Non-persistent complex buffer, so callers can check how complex tensors are treated
        # when the module is converted to double.
        self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), persistent=False)

    def forward(self, x):
        """Apply the linear layer after checking the input dtype and the default dtype are float64."""
        assert x.dtype == torch.float64
        # the default dtype for new tensors is now float64
        assert torch.tensor([0.0]).dtype == torch.float64
        return self.layer(x)
32+
33+
34+
class DoublePrecisionBoringLite(BoringLite):
    """``BoringLite`` variant that checks dtypes at each stage of a double-precision training step."""

    def get_model(self):
        """Return the module that asserts float64 inputs inside its forward pass."""
        return BoringDoubleModule()

    def step(self, model, batch):
        """Check parameter, buffer, batch and output dtypes, then return the MSE loss."""
        model.double()  # TODO(lite): this needs to be done automatically in Lite.setup()
        assert model.layer.weight.dtype == model.layer.bias.dtype == torch.float64
        # The complex buffer is expected to keep its original dtype after the conversion above.
        assert model.complex_buffer.dtype == torch.complex64

        # Outside the module both the batch and the output are expected in float32, while the module's
        # own forward asserts float64 inputs.
        # NOTE(review): presumably the wrapper returned by ``setup()`` casts in and out - confirm.
        assert batch.dtype == torch.float32
        output = model(batch)
        assert output.dtype == torch.float32
        loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
        return loss

    def after_backward(self, model):
        """Check that the gradients were produced in float64."""
        assert model.layer.weight.grad.dtype == torch.float64
51+
52+
53+
def test_double_precision():
    """Run one training step with ``precision=64``; the dtype checks live in ``DoublePrecisionBoringLite``."""
    lite = DoublePrecisionBoringLite(precision=64)
    lite.run()

tests/tests_lite/plugins/precision/test_native_amp.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,25 @@ def test_native_amp_precision_bf16_min_torch():
4343

4444
@RunIf(min_torch="1.10")
4545
def test_native_amp_precision_forward_context():
46-
precision_plugin = NativeMixedPrecision(precision="mixed", device="cuda")
46+
"""Test to ensure that the context manager correctly is set to CPU + bfloat16."""
47+
precision_plugin = NativeMixedPrecision(precision=16, device="cuda")
48+
assert precision_plugin.device == "cuda"
49+
assert isinstance(precision_plugin.scaler, torch.cuda.amp.GradScaler)
4750
assert torch.get_default_dtype() == torch.float32
4851
with precision_plugin.forward_context():
4952
assert torch.get_autocast_gpu_dtype() == torch.float16
5053

54+
precision_plugin = NativeMixedPrecision(precision="bf16", device="cpu")
55+
assert precision_plugin.device == "cpu"
56+
assert precision_plugin.scaler is None
57+
with precision_plugin.forward_context():
58+
assert torch.get_autocast_cpu_dtype() == torch.bfloat16
59+
60+
context_manager = precision_plugin._autocast_context_manager()
61+
assert isinstance(context_manager, torch.autocast)
62+
# check with str due to a bug upstream: https://github.com/pytorch/pytorch/issues/65786
63+
assert str(context_manager.fast_dtype) == str(torch.bfloat16)
64+
5165

5266
def test_native_amp_precision_backward():
5367
precision_plugin = NativeMixedPrecision(precision="mixed", device="cuda")
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Integration tests for native automatic mixed precision (AMP) training."""
15+
import pytest
16+
import torch
17+
import torch.nn as nn
18+
from tests_lite.helpers.models import BoringLite
19+
from tests_lite.helpers.runif import RunIf
20+
21+
22+
class NativeMixedPrecisionModule(nn.Module):
    """Linear module whose forward pass asserts that autocast is active with the expected dtype.

    Args:
        expected_dtype: The dtype that both the input and the layer output must have inside ``forward``.
    """

    def __init__(self, expected_dtype):
        super().__init__()
        self.expected_dtype = expected_dtype
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer, checking dtypes and the autocast state for the input's device."""
        assert x.dtype == self.expected_dtype
        # CPU autocast has its own query function; every other device type uses the generic one.
        if x.device.type == "cpu":
            assert torch.is_autocast_cpu_enabled()
        else:
            assert torch.is_autocast_enabled()
        output = self.layer(x)
        assert output.dtype == self.expected_dtype
        return output
37+
38+
39+
class NativeMixedPrecisionBoringLite(BoringLite):
    """``BoringLite`` variant that checks dtypes at each stage of a native-AMP training step."""

    # Dtype expected inside the module's forward pass; must be assigned before ``run()`` is called.
    expected_dtype: torch.dtype

    def get_model(self):
        """Return the module that asserts ``expected_dtype`` inside its forward pass."""
        return NativeMixedPrecisionModule(self.expected_dtype)

    def step(self, model, batch):
        """Check that weights, batch and output are float32 outside the module, then return the MSE loss."""
        assert model.layer.weight.dtype == torch.float32

        assert batch.dtype == torch.float32
        output = model(batch)
        assert output.dtype == torch.float32
        loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
        return loss

    def after_backward(self, model):
        """Check that the gradients were produced in float32."""
        assert model.layer.weight.grad.dtype == torch.float32
57+
58+
59+
@RunIf(min_torch="1.10")
@pytest.mark.parametrize(
    "accelerator, precision, expected_dtype",
    [
        ("cpu", 16, torch.bfloat16),
        ("cpu", "bf16", torch.bfloat16),
        pytest.param("cuda", 16, torch.float16, marks=RunIf(min_cuda_gpus=1)),
        pytest.param("cuda", "bf16", torch.bfloat16, marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)),
    ],
)
def test_native_mixed_precision(accelerator, precision, expected_dtype):
    """Run one AMP training step for each accelerator/precision pair and check the autocast dtype."""
    # Forward the parametrized precision: hard-coding 16 here would leave the "bf16" cases untested.
    lite = NativeMixedPrecisionBoringLite(accelerator=accelerator, precision=precision)
    lite.expected_dtype = expected_dtype
    lite.run()

tests/tests_lite/test_connector.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from lightning_lite.accelerators.cuda import CUDAAccelerator
2828
from lightning_lite.accelerators.mps import MPSAccelerator
2929
from lightning_lite.connector import _Connector
30-
from lightning_lite.plugins import DoublePrecision, Precision
30+
from lightning_lite.plugins import DoublePrecision, NativeMixedPrecision, Precision
3131
from lightning_lite.plugins.environments import (
3232
KubeflowEnvironment,
3333
LightningEnvironment,
@@ -692,3 +692,44 @@ def test_gpu_accelerator_no_gpu_backend_found_error(*_):
692692
def test_ddp_fork_on_unsupported_platform(_, strategy):
693693
with pytest.raises(ValueError, match="process forking is not supported on this platform"):
694694
_Connector(strategy=strategy)
695+
696+
697+
@mock.patch("lightning_lite.plugins.precision.native_amp._TORCH_GREATER_EQUAL_1_10", True)
def test_precision_selection_16_on_cpu_warns():
    """Requesting ``precision=16`` without a GPU accelerator must warn about the switch to bf16."""
    with pytest.warns(
        UserWarning, match=r"precision=16\)` but native AMP is not supported on CPU. Using `precision='bf16"
    ):
        _Connector(precision=16)
703+
704+
705+
@mock.patch("lightning_lite.plugins.precision.native_amp._TORCH_GREATER_EQUAL_1_10", False)
def test_precision_selection_16_raises_torch_version():
    """Both 16-bit precision flavours on CPU must raise when the torch >= 1.10 flag is patched to False."""
    with pytest.raises(ImportError, match="must install torch greater or equal to 1.10"):
        _Connector(accelerator="cpu", precision=16)
    with pytest.raises(ImportError, match="must install torch greater or equal to 1.10"):
        _Connector(accelerator="cpu", precision="bf16")
711+
712+
713+
class MyNativeAMP(NativeMixedPrecision):
    """Trivial ``NativeMixedPrecision`` subclass standing in for a user-defined precision plugin."""
715+
716+
717+
@RunIf(mps=False)
@pytest.mark.parametrize("strategy,devices", [("ddp", 2), ("ddp_spawn", 2)])
@pytest.mark.parametrize(
    "is_custom_plugin,plugin_cls",
    [(False, NativeMixedPrecision), (True, MyNativeAMP)],
)
@mock.patch("lightning_lite.plugins.precision.native_amp._TORCH_GREATER_EQUAL_1_10", True)
def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin_cls):
    """Check the connector's precision plugin type under DDP, with and without a user-supplied plugin."""
    # Only the custom case passes a plugin instance; otherwise the connector must create one itself.
    plugin = None
    if is_custom_plugin:
        plugin = plugin_cls(16, "cpu")

    connector = _Connector(
        precision=16,
        devices=devices,
        strategy=strategy,
        plugins=plugin,
    )
    assert isinstance(connector.precision_plugin, plugin_cls)

tests/tests_lite/test_parity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import torch.multiprocessing as mp
2424
import torch.nn.functional
2525
from lightning_utilities.core.apply_func import apply_to_collection
26+
from tests_lite.helpers.models import RandomDataset
2627
from tests_lite.helpers.runif import RunIf
2728
from torch import nn
2829
from torch.nn.parallel.distributed import DistributedDataParallel
@@ -34,7 +35,6 @@
3435
from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
3536
from lightning_lite.utilities.apply_func import move_data_to_device
3637
from lightning_lite.utilities.cloud_io import atomic_save
37-
from pytorch_lightning.demos.boring_classes import RandomDataset
3838

3939

4040
class BoringModel(nn.Module):

tests/tests_lite/utilities/test_data.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44
import torch
5+
from tests_lite.helpers.models import RandomDataset, RandomIterableDataset
56
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
67

78
from lightning_lite.utilities.data import (
@@ -16,9 +17,6 @@
1617
)
1718
from lightning_lite.utilities.exceptions import MisconfigurationException
1819

19-
# TODO(lite): provide boring classes in Lite
20-
from pytorch_lightning.demos.boring_classes import RandomDataset, RandomIterableDataset
21-
2220

2321
def test_has_iterable_dataset():
2422
assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))

0 commit comments

Comments
 (0)