Commit 9d136a9

Lightning Lite core and tests (#10175)
1 parent b4f43b1 commit 9d136a9

File tree

13 files changed: +1398, -11 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -219,6 +219,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
    * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
    * Updated precision attributes in `DeepSpeedPlugin` ([#10164](https://github.com/PyTorchLightning/pytorch-lightning/pull/10164))
    * Added the ability to return a result from rank 0 in `DDPSpawnPlugin.spawn` ([#10162](https://github.com/PyTorchLightning/pytorch-lightning/pull/10162))
+   * Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175))

- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))

pytorch_lightning/lite/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.lite.lite import LightningLite

__all__ = ["LightningLite"]

pytorch_lightning/lite/lite.py

Lines changed: 501 additions & 0 deletions
Large diffs are not rendered by default.
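Since the 501-line lite.py diff is collapsed, the sketch below illustrates how a LightningLite subclass is typically driven. It is a minimal sketch under assumptions: the constructor argument and the run / setup / setup_dataloaders / backward methods are assumed here, not taken from the rendered diff.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.lite import LightningLite


class ToyLite(LightningLite):
    def run(self, num_epochs: int = 1) -> None:
        model = nn.Linear(32, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        # setup() is assumed to return the model/optimizer wrapped for the chosen accelerator and precision
        model, optimizer = self.setup(model, optimizer)

        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        # setup_dataloaders() is assumed to inject distributed sampling and device transfers
        dataloader = self.setup_dataloaders(DataLoader(dataset, batch_size=16))

        model.train()
        for _ in range(num_epochs):
            for x, y in dataloader:
                optimizer.zero_grad()
                loss = torch.nn.functional.cross_entropy(model(x), y)
                # backward() stands in for loss.backward() so precision/strategy plugins can intercept it
                self.backward(loss)
                optimizer.step()


if __name__ == "__main__":
    ToyLite(accelerator="cpu").run()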

pytorch_lightning/lite/wrappers.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Generator, Iterator, Optional, Union

import torch
from torch import nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device


def _do_nothing_closure() -> None:
    return None


class _LiteOptimizer:
    def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
        """LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
        step calls to the accelerator/strategy plugin.

        The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.

        Args:
            optimizer: The optimizer to wrap
            accelerator: Reference to the accelerator for handling the optimizer step
        """
        # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
        # not want to call on destruction of the `_LiteOptimizer`
        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
        self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
        self._optimizer = optimizer
        self._accelerator = accelerator

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer

    def step(self, closure: Optional[Callable] = None) -> None:
        closure = closure or _do_nothing_closure
        self._accelerator.optimizer_step(
            self.optimizer,
            opt_idx=0,
            closure=closure,
            model=self._accelerator.model,
        )


class _LiteModule(nn.Module):
    def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None:
        """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
        automatically for the forward pass.

        The underlying wrapped module can be accessed via the property :attr:`module`.

        Args:
            module: The module to wrap
            precision_plugin: Reference to the precision plugin for handling precision context
        """
        super().__init__()
        self._module = module
        self._precision_plugin = precision_plugin

    @property
    def module(self) -> nn.Module:
        return self._module

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Casts all inputs to the right precision and handles autocast for operations in the module forward
        method."""
        precision = self._precision_plugin.precision
        precision_to_type = {
            "bf16": torch.bfloat16,
            16: torch.float16,
            32: torch.float32,
            64: torch.float64,
        }
        # TODO (@awaelchli): let the precision plugin handle the conversion
        to_type = precision_to_type[precision]
        args, kwargs = apply_to_collection([args, kwargs], function=lambda t: t.to(to_type), dtype=Tensor)

        with self._precision_plugin.forward_context():
            output = self.module(*args, **kwargs)

        output = apply_to_collection(output, function=lambda t: t.to(torch.get_default_dtype()), dtype=Tensor)
        return output


class _LiteDataLoader(DataLoader):
    def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> None:
        """The LiteDataLoader is an extension of the PyTorch :class:`~torch.utils.data.DataLoader` that adds
        additional features such as moving the data to the device automatically.

        Args:
            device: The device to which the data should be moved. By default the device is `None` and no data
                transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
            **dl_kwargs: Accepts all arguments that the PyTorch :class:`~torch.utils.data.DataLoader` accepts.
        """
        super().__init__(**dl_kwargs)
        self._device = device

    @property
    def device(self) -> Optional[torch.device]:
        return self._device

    def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
        iterator = super().__iter__()
        if self._device is None:
            return iterator

        for item in iterator:
            yield move_data_to_device(item, self._device)
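The least obvious part of `_LiteOptimizer` above is the dynamic re-classing in `__init__`: the wrapper takes over the optimizer's `__dict__` and swaps its own class for a new type that inherits from both the wrapper and the concrete optimizer class, so `isinstance` checks and state access keep working while `step()` is intercepted. A standalone sketch of the same pattern (toy wrapper, not the PL implementation):

import torch
from torch.optim import SGD, Optimizer


class _WrappedOptimizer:
    """Minimal reproduction of the re-classing trick used by `_LiteOptimizer`."""

    def __init__(self, optimizer: Optimizer) -> None:
        # take over the wrapped optimizer's state (defaults, state, param_groups, ...)
        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
        # build a new class that is both the wrapper and the concrete optimizer type
        self.__class__ = type("Wrapped" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
        self._optimizer = optimizer

    def step(self, closure=None) -> None:
        # this is where Lite would hand the call to the accelerator; here we simply delegate
        print("intercepted step()")
        self._optimizer.step(closure)


model = torch.nn.Linear(4, 1)
opt = _WrappedOptimizer(SGD(model.parameters(), lr=0.1))

assert isinstance(opt, SGD)          # downstream code still sees a regular SGD
assert len(opt.param_groups) == 1    # state was carried over from the wrapped optimizer

loss = model(torch.randn(2, 4)).sum()
loss.backward()
opt.step()                           # routed through the wrapper's step()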

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 2 additions & 1 deletion
@@ -336,7 +336,8 @@ def precision(self) -> Union[str, int]:

    @property
    def amp_level(self) -> Optional[str]:
-        return self._amp_level or self.lightning_module.trainer._accelerator_connector.amp_level
+        if self._amp_type == AMPType.APEX:
+            return self._amp_level or self.lightning_module.trainer._accelerator_connector.amp_level

    @property
    def amp_type(self) -> Optional[str]:

pytorch_lightning/trainer/data_loading.py

Lines changed: 7 additions & 4 deletions
@@ -114,9 +114,10 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
                " in the `DataLoader` init to improve performance."
            )

-    def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
+    @staticmethod
+    def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
        if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
-            dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
+            dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)

    def _requires_distributed_sampler(self, dataloader) -> bool:
        return (

@@ -336,7 +337,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
        apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")

        # add worker_init_fn for correct seeding in worker processes
-        apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)
+        apply_to_collection(self.train_dataloader, DataLoader, self._auto_add_worker_init_fn, rank=self.global_rank)

        # add collate_fn to collect metadata for fault tolerant training
        if _fault_tolerant_training():

@@ -443,7 +444,9 @@ def _reset_eval_dataloader(
        dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]

        # add worker_init_fn for correct seeding in worker processes
-        apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)
+        apply_to_collection(
+            dataloaders, dtype=DataLoader, function=self._auto_add_worker_init_fn, rank=self.global_rank
+        )

        loader_num_batches = []
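The data_loading.py change above turns `auto_add_worker_init_fn` into a static method parameterized by `rank`, so it no longer needs a Trainer instance. The sketch below shows the underlying worker-seeding pattern in isolation; the seeding function is a simplified stand-in, not PL's actual `pl_worker_init_function`.

import os
import random
from functools import partial

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class _RangeDataset(Dataset):
    """Tiny stand-in dataset so the sketch is runnable."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int) -> int:
        return index


def _seed_worker(worker_id: int, rank: int) -> None:
    # derive a distinct, reproducible seed per (rank, worker) pair; PL's
    # pl_worker_init_function follows the same idea with a more careful scheme
    seed = (torch.initial_seed() + rank * 1000 + worker_id) % 2**31
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
    # mirror the static method above: only attach a worker_init_fn when the user
    # opted into seeded workers and did not supply their own function
    if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
        dataloader.worker_init_fn = partial(_seed_worker, rank=rank)


if __name__ == "__main__":
    os.environ["PL_SEED_WORKERS"] = "1"
    loader = DataLoader(_RangeDataset(), num_workers=2)
    _auto_add_worker_init_fn(loader, rank=0)
    print([batch.item() for batch in loader])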

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -559,7 +559,7 @@ def __init__(
            if gradient_clip_algorithm is not None
            else gradient_clip_algorithm
        )
-        self.track_grad_norm = float(track_grad_norm)
+        self.track_grad_norm: float = float(track_grad_norm)

        self._detect_anomaly: bool = detect_anomaly
        self._setup_on_init(num_sanity_val_steps)

tests/helpers/pipelines.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def run_model_test(
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    # Check that the model is actually changed post-training
    change_ratio = torch.norm(initial_values - post_train_values)
-    assert change_ratio > 0.1, f"the model is changed of {change_ratio}"
+    assert change_ratio > 0.03, f"the model is changed of {change_ratio}"

    # test model loading
    pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))

tests/lite/__init__.py

Whitespace-only changes.
