
Commit 5262b63

carmocca and tchaton authored
Pass the scaler as an input to NativeMixedPrecisionPlugin (#10055)
Co-authored-by: thomas chaton <[email protected]>
1 parent 83d74bb commit 5262b63

File tree: 8 files changed (+90, -85 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -344,6 +344,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Moved the `optimizer_step` and `clip_gradients` hook from the `Accelerator` and `TrainingTypePlugin` into the `PrecisionPlugin` ([#10143](https://github.com/PyTorchLightning/pytorch-lightning/pull/10143), [#10029](https://github.com/PyTorchLightning/pytorch-lightning/pull/10029))
 
 
+- `NativeMixedPrecisionPlugin` and its subclasses now take an optional `GradScaler` instance ([#10055](https://github.com/PyTorchLightning/pytorch-lightning/pull/10055))
+
+
 - Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))
 
 
docs/source/extensions/accelerators.rst

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ One to handle differences from the training routine and one to handle different
     from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin
 
     accelerator = GPUAccelerator(
-        precision_plugin=NativeMixedPrecisionPlugin(),
+        precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"),
         training_type_plugin=DDPPlugin(),
     )
     trainer = Trainer(accelerator=accelerator)
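
For illustration only (not part of the documented example above), a minimal sketch of what this change enables: a pre-configured torch.cuda.amp.GradScaler can now be passed as the third argument; the init_scale value here is arbitrary.

    import torch
    from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

    # any GradScaler works; init_scale=2**10 is only an illustrative choice
    scaler = torch.cuda.amp.GradScaler(init_scale=2.0**10)
    precision_plugin = NativeMixedPrecisionPlugin(16, "cuda", scaler=scaler)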

pytorch_lightning/plugins/precision/fully_sharded_native_amp.py

Lines changed: 1 addition & 3 deletions
@@ -18,9 +18,7 @@
 
 
 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
-    """Mixed Precision for Full Sharded Training."""
-
-    precision = "mixed"
+    """Native AMP for Fully Sharded Training."""
 
     def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
         # see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 30 additions & 30 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, Union
+from typing import Any, Callable, Dict, Generator, Optional, Union
 
 import torch
 from torch import Tensor
@@ -31,41 +31,39 @@
 
 
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
-    """Plugin for native mixed precision training with :mod:`torch.cuda.amp`.
+    """Plugin for Native Mixed Precision (AMP) training with ``torch.autocast``.
 
     Args:
-        precision: Whether to use torch.float16 (16) or torch.bfloat16 (bf16).
+        precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``).
+        device: The device for ``torch.autocast``.
+        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
     """
 
-    def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
-        super().__init__()
-        self.use_cpu = use_cpu
-        self._dtype = self._select_precision_dtype(precision)
-        self.backend = AMPType.NATIVE
-        if not self.is_bfloat16:
-            self.scaler = torch.cuda.amp.GradScaler()
-
-    def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
-        if precision == "bf16":
-            if not _TORCH_GREATER_EQUAL_1_10:
-                raise MisconfigurationException(
-                    "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
-                )
-            return torch.bfloat16
-        return torch.float16
+    backend = AMPType.NATIVE
 
-    @property
-    def is_bfloat16(self) -> bool:
-        return self._dtype == torch.bfloat16
+    def __init__(
+        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
+    ) -> None:
+        super().__init__()
+        if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
+            raise MisconfigurationException(
+                "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
+            )
+        if scaler is None and precision == 16:
+            scaler = torch.cuda.amp.GradScaler()
+        if scaler is not None and precision == "bf16":
+            raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
+        self.precision = precision
+        self.device = device
+        self.scaler = scaler
 
     def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor:
-        if self.is_bfloat16:
-            return super().pre_backward(model, closure_loss)
-        closure_loss = self.scaler.scale(closure_loss)
+        if self.scaler is not None:
+            closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)
 
     def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
-        if not self.is_bfloat16:
+        if self.scaler is not None:
             tensor = self.scaler.scale(tensor)
         super()._run_backward(tensor, model, *args, **kwargs)
 
@@ -77,7 +75,7 @@ def optimizer_step(
         lambda_closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
-        if self.is_bfloat16:
+        if self.scaler is None:
             # skip scaler logic, as bfloat16 does not require scaler
             return super().optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         if isinstance(optimizer, LBFGS):
@@ -98,7 +96,9 @@ def optimizer_step(
 
     def autocast_context_manager(self) -> autocast:
         if _TORCH_GREATER_EQUAL_1_10:
-            return autocast("cpu" if self.use_cpu else "cuda", dtype=self._dtype)
+            # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
+            # https://github.com/pytorch/pytorch/issues/67233
+            return autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
         return autocast()
 
     @contextmanager
@@ -108,9 +108,9 @@ def forward_context(self) -> Generator[None, None, None]:
         yield
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if "native_amp_scaling_state" in checkpoint and not self.is_bfloat16:
+        if self.scaler is not None and "native_amp_scaling_state" in checkpoint:
             self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])
 
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if not self.is_bfloat16:
+        if self.scaler is not None:
             checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
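
A minimal sketch of the behaviour the new constructor encodes, assuming torch >= 1.10 and the import path shown in the docs example above: fp16 gets a default GradScaler, bf16 runs scaler-free, and the autocast dtype follows the precision.

    import torch
    from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

    fp16_plugin = NativeMixedPrecisionPlugin(16, "cuda")
    assert isinstance(fp16_plugin.scaler, torch.cuda.amp.GradScaler)  # created when none is passed

    bf16_plugin = NativeMixedPrecisionPlugin("bf16", "cpu")
    assert bf16_plugin.scaler is None  # bf16 never uses a scaler

    # the context manager is torch.autocast with the configured device and dtype
    ctx = bf16_plugin.autocast_context_manager()
    assert str(ctx.fast_dtype) == str(torch.bfloat16)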

pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 14 additions & 6 deletions
@@ -11,23 +11,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from typing import Optional, Union
+
+import torch
 
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler
 
 
 class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
-    """Mixed Precision for Sharded Training."""
+    """Native AMP for Sharded Training."""
 
-    def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
-        super().__init__(precision, use_cpu=use_cpu)
-        if not self.use_cpu:
-            self.scaler = ShardedGradScaler()
+    def __init__(
+        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
+    ) -> None:
+        if not _FAIRSCALE_AVAILABLE:
+            raise MisconfigurationException(
+                "You have asked for sharded AMP but you have not installed it."
+                " Install `fairscale` using this guide: https://https://github.com/facebookresearch/fairscale"
+            )
+        super().__init__(precision, device, scaler=scaler or ShardedGradScaler())
 
     def clip_grad_by_norm(self, optimizer: "OSS", clip_val: Union[int, float]) -> None:
         optimizer.clip_grad_norm(clip_val)
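
A short sketch of the resulting default, assuming fairscale is installed: when no scaler is passed, the sharded plugin falls back to fairscale's ShardedGradScaler.

    from fairscale.optim.grad_scaler import ShardedGradScaler
    from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin

    plugin = ShardedNativeMixedPrecisionPlugin(16, "cuda")
    assert isinstance(plugin.scaler, ShardedGradScaler)  # created via `scaler or ShardedGradScaler()`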

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 17 additions & 19 deletions
@@ -638,16 +638,27 @@ def select_precision_plugin(self) -> PrecisionPlugin:
             )
             self.precision = "bf16"
 
-        if self.precision == 16:
-            rank_zero_info(f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)")
+        if self.precision in (16, "bf16"):
+            if self.precision == "bf16" and self.amp_type != AMPType.NATIVE:
+                raise MisconfigurationException(
+                    f"You passed `Trainer(amp_type={self.amp_type.value!r}, precision='bf16')` but it's not supported."
+                    " Try using `amp_type='native'` instead."
+                )
+
+            rank_zero_info(
+                f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)"
+                if self.precision == 16
+                else "Using bfloat16 Automatic Mixed Precision (AMP)"
+            )
 
             if self.amp_type == AMPType.NATIVE:
+                device = "cpu" if self.use_cpu else "cuda"
+
                 if self._is_sharded_training_type:
-                    return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
+                    return ShardedNativeMixedPrecisionPlugin(self.precision, device)
                 if self._is_fully_sharded_training_type:
-                    return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
-
-                return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
+                    return FullyShardedNativeMixedPrecisionPlugin(self.precision, device)
+                return NativeMixedPrecisionPlugin(self.precision, device)
 
             if self.amp_type == AMPType.APEX:
                 if self._is_sharded_training_type or self._is_fully_sharded_training_type:
@@ -657,19 +668,6 @@ def select_precision_plugin(self) -> PrecisionPlugin:
                 self.amp_level = self.amp_level or "O2"
                 return ApexMixedPrecisionPlugin(self.amp_level)
 
-        if self.precision == "bf16":
-            if self.amp_type != AMPType.NATIVE:
-                raise MisconfigurationException(
-                    "You passed `Trainer(amp_type='apex', precision='bf16')` but it's not supported."
-                    " Try using `amp_type='native'` instead."
-                )
-            rank_zero_info("Using bfloat16 Automatic Mixed Precision (AMP)")
-            if self._is_sharded_training_type:
-                return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
-            if self._is_fully_sharded_training_type:
-                return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
-            return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
-
         raise RuntimeError("No precision set")
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
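
At the Trainer level, the selection above boils down to something like this hedged sketch; flag names are taken from the tests below, and bf16 additionally needs torch >= 1.10.

    from pytorch_lightning import Trainer

    # bfloat16 is only supported with the native AMP backend after this change;
    # on a CPU-only run the plugin is constructed with device="cpu", otherwise "cuda"
    trainer = Trainer(precision="bf16", amp_backend="native")

    # this combination now raises a MisconfigurationException:
    # Trainer(precision="bf16", amp_backend="apex")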

tests/models/test_amp.py

Lines changed: 11 additions & 11 deletions
@@ -27,41 +27,41 @@
 
 
 class AMPTestModel(BoringModel):
-    def _step(self, batch, batch_idx):
+    def _step(self, batch):
         self._assert_autocast_enabled()
         output = self(batch)
-        bfloat16 = self.trainer.precision_plugin.is_bfloat16
-        assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16
+        is_bfloat16 = self.trainer.precision_plugin.precision == "bf16"
+        assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16
         loss = self.loss(batch, output)
         return loss
 
     def loss(self, batch, prediction):
         # todo (sean): convert bfloat16 to float32 as mse loss for cpu amp is currently not supported
-        if self.trainer.precision_plugin.use_cpu:
+        if self.trainer.precision_plugin.device == "cpu":
             prediction = prediction.float()
         return super().loss(batch, prediction)
 
     def training_step(self, batch, batch_idx):
-        output = self._step(batch, batch_idx)
+        output = self._step(batch)
         return {"loss": output}
 
     def validation_step(self, batch, batch_idx):
-        output = self._step(batch, batch_idx)
+        output = self._step(batch)
         return {"x": output}
 
     def test_step(self, batch, batch_idx):
-        output = self._step(batch, batch_idx)
+        output = self._step(batch)
         return {"y": output}
 
-    def predict(self, batch, batch_idx, dataloader_idx=None):
+    def predict_step(self, batch, batch_idx, dataloader_idx=None):
         self._assert_autocast_enabled()
         output = self(batch)
-        bfloat16 = self.trainer.precision_plugin.is_bfloat16
-        assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16
+        is_bfloat16 = self.trainer.precision_plugin.precision == "bf16"
+        assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16
         return output
 
     def _assert_autocast_enabled(self):
-        if self.trainer.precision_plugin.use_cpu:
+        if self.trainer.precision_plugin.device == "cpu":
             assert torch.is_autocast_cpu_enabled()
         else:
             assert torch.is_autocast_enabled()

tests/plugins/test_amp_plugins.py

Lines changed: 13 additions & 15 deletions
@@ -20,7 +20,6 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
-from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -47,7 +46,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin):
     },
 )
 @mock.patch("torch.cuda.device_count", return_value=2)
-@pytest.mark.parametrize("ddp_backend,gpus", [("ddp", 2), ("ddp2", 2), ("ddp_spawn", 2)])
+@pytest.mark.parametrize("strategy,gpus", [("ddp", 2), ("ddp2", 2), ("ddp_spawn", 2)])
 @pytest.mark.parametrize(
     "amp,custom_plugin,plugin_cls",
     [
@@ -57,21 +56,19 @@ class MyApexPlugin(ApexMixedPrecisionPlugin):
         pytest.param("apex", True, MyApexPlugin, marks=RunIf(amp_apex=True)),
     ],
 )
-def test_amp_apex_ddp(
-    mocked_device_count, ddp_backend: str, gpus: int, amp: str, custom_plugin: bool, plugin_cls: MixedPrecisionPlugin
-):
-
+def test_amp_apex_ddp(mocked_device_count, strategy, gpus, amp, custom_plugin, plugin_cls):
+    plugin = None
+    if custom_plugin:
+        plugin = plugin_cls(16, "cpu") if amp == "native" else plugin_cls()
     trainer = Trainer(
         fast_dev_run=True,
         precision=16,
         amp_backend=amp,
         gpus=gpus,
-        strategy=ddp_backend,
-        plugins=[plugin_cls()] if custom_plugin else None,
+        strategy=strategy,
+        plugins=plugin,
     )
     assert isinstance(trainer.precision_plugin, plugin_cls)
-    if amp == "native":
-        assert not trainer.precision_plugin.is_bfloat16
 
 
 class GradientUnscaleBoringModel(BoringModel):
@@ -179,13 +176,14 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
 
 @RunIf(min_torch="1.10")
 def test_cpu_amp_precision_context_manager(tmpdir):
-    """Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set."""
-    plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
-    assert plugin.use_cpu
-    assert not hasattr(plugin, "scaler")
+    """Test to ensure that the context manager correctly is set to CPU + bfloat16."""
+    plugin = NativeMixedPrecisionPlugin("bf16", "cpu")
+    assert plugin.device == "cpu"
+    assert plugin.scaler is None
     context_manager = plugin.autocast_context_manager()
     assert isinstance(context_manager, torch.autocast)
-    assert context_manager.fast_dtype == torch.bfloat16
+    # check with str due to a bug upstream: https://github.com/pytorch/pytorch/issues/65786
+    assert str(context_manager.fast_dtype) == str(torch.bfloat16)
 
 
 def test_precision_selection_raises(monkeypatch):
