
Commit 2294b1c

fix gpu setup environment
1 parent c3e0b5c commit 2294b1c

5 files changed (+24 −32 lines)

docs/source/extensions/accelerators.rst

Lines changed: 3 additions & 1 deletion

@@ -26,7 +26,9 @@ One to handle differences from the training routine and one to handle different
     from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin

     accelerator = GPUAccelerator()
-    trainer = Trainer(accelerator=accelerator)
+    precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda")
+    training_type_plugin = DDPPlugin(accelerator=accelerator, precision_plugin=precision_plugin)
+    trainer = Trainer(strategy=training_type_plugin)


 We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to work with new
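For context, the updated docs snippet can be assembled into a self-contained example roughly as below. The Trainer and GPUAccelerator imports are assumed from the surrounding docs page and are not part of this diff; treat this as a sketch against the API at this commit.

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import GPUAccelerator
    from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin

    # Compose the pieces explicitly: accelerator -> precision plugin -> training type plugin.
    accelerator = GPUAccelerator()
    precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda")
    training_type_plugin = DDPPlugin(accelerator=accelerator, precision_plugin=precision_plugin)

    # The Trainer now receives the training type plugin via `strategy`
    # instead of being given the accelerator directly.
    trainer = Trainer(strategy=training_type_plugin)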

pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ class Accelerator:
     - IPU
     """

-    def setup_environment(self) -> None:
+    def setup_environment(self, root_device: torch.device) -> None:
         """Setup any processes or distributed connections.

         This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator
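In effect, accelerators no longer read self.root_device during environment setup; the device is handed in by the caller. A minimal, hypothetical sketch of a subclass written against the new hook (the class name and print are illustrative, not part of this commit):

    import torch

    from pytorch_lightning.accelerators import GPUAccelerator


    class LoggingGPUAccelerator(GPUAccelerator):  # hypothetical subclass for illustration
        def setup_environment(self, root_device: torch.device) -> None:
            # root_device now arrives as an argument rather than being read
            # from accelerator state (the old hook took no parameters).
            print(f"setting up accelerator on {root_device}")
            super().setup_environment(root_device)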

pytorch_lightning/accelerators/cpu.py

Lines changed: 9 additions & 17 deletions

@@ -15,29 +15,21 @@

 import torch

-# import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
-
-# from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class CPUAccelerator(Accelerator):
     """Accelerator for CPU devices."""

-    # @property
-    # def root_device(self):
-    #     return torch.device("cpu")
-
-    # def setup(self, trainer: "pl.Trainer") -> None:
-    #     """
-    #     Raises:
-    #         MisconfigurationException:
-    #             If the selected device is not CPU.
-    #     """
-    #     if "cpu" not in str(self.root_device):
-    #         raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.")
-
-    #     return super().setup(trainer)
+    def setup_environment(self, root_device: torch.device) -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If the selected device is not CPU.
+        """
+        if "cpu" not in str(root_device):
+            raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")

     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """CPU device stats aren't supported yet."""

pytorch_lightning/accelerators/gpu.py

Lines changed: 10 additions & 12 deletions

@@ -21,8 +21,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
-
-# from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8

 _log = logging.getLogger(__name__)
@@ -31,19 +30,18 @@
 class GPUAccelerator(Accelerator):
     """Accelerator for GPU devices."""

-    # def setup_environment(self) -> None:
-    #     """
-    #     Raises:
-    #         MisconfigurationException:
-    #             If the selected device is not GPU.
-    #     """
-    #     if "cuda" not in str(self.root_device):
-    #         raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
-    #     torch.cuda.set_device(self.root_device)
+    def setup_environment(self, root_device: torch.device) -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If the selected device is not GPU.
+        """
+        if "cuda" not in str(root_device):
+            raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
+        torch.cuda.set_device(root_device)

     def setup(self, trainer: "pl.Trainer") -> None:
         self.set_nvidia_flags(trainer.local_rank)
-        super().setup(trainer)
         # clear cache before training
         torch.cuda.empty_cache()

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 1 deletion

@@ -93,7 +93,7 @@ def setup_environment(self) -> None:
         This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator
         environment before setup is complete.
         """
-        self.accelerator.setup_environment()
+        self.accelerator.setup_environment(self.root_device)

     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         """Creates optimizers and schedulers.
