
Commit a377331

committed
pull from adrian's commit
1 parent db2feef commit a377331

8 files changed (+37, -16 lines)


pytorch_lightning/accelerators/cpu.py (1 addition, 3 deletions)

@@ -34,9 +34,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
             If the selected device is not CPU.
         """
         if "cpu" not in str(self.root_device):
-            raise MisconfigurationException(
-                f"Device should be CPU, got {self.root_device} instead."
-            )
+            raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.")
 
         return super().setup(trainer)
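For context, a self-contained sketch of the guard this hunk collapses onto one line. The assert_cpu helper below is hypothetical; the real code raises Lightning's MisconfigurationException rather than ValueError:

import torch

def assert_cpu(root_device: torch.device) -> None:
    # Same check as the diff: the device's string form must contain "cpu".
    if "cpu" not in str(root_device):
        raise ValueError(f"Device should be CPU, got {root_device} instead.")

assert_cpu(torch.device("cpu"))       # passes silently
# assert_cpu(torch.device("cuda:0"))  # would raise ValueError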

pytorch_lightning/accelerators/gpu.py (1 addition, 3 deletions)

@@ -38,9 +38,7 @@ def setup_environment(self) -> None:
         """
         super().setup_environment()
         if "cuda" not in str(self.root_device):
-            raise MisconfigurationException(
-                f"Device should be GPU, got {self.root_device} instead"
-            )
+            raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         torch.cuda.set_device(self.root_device)
 
     def setup(self, trainer: "pl.Trainer") -> None:

pytorch_lightning/plugins/training_type/single_tpu.py (3 additions, 1 deletion)

@@ -42,7 +42,9 @@ def __init__(
 
         device = xm.xla_device(device)
         checkpoint_io = checkpoint_io or XLACheckpointIO()
-        super().__init__(accelerator=accelerator, device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
+        super().__init__(
+            accelerator=accelerator, device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin
+        )
 
         self.debug = debug
         self.tpu_local_core_rank = 0

pytorch_lightning/plugins/training_type/tpu_spawn.py (3 additions, 1 deletion)

@@ -64,7 +64,9 @@ def __init__(
         checkpoint_io = checkpoint_io or XLACheckpointIO()
         super().__init__(
             accelerator=accelerator,
-            parallel_devices=parallel_devices, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin,
+            parallel_devices=parallel_devices,
+            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
         )
         self.debug = debug
         self.tpu_local_core_rank = 0

pytorch_lightning/plugins/training_type/training_type_plugin.py (4 additions, 1 deletion)

@@ -42,7 +42,10 @@ class TrainingTypePlugin(ABC):
     loop."""
 
     def __init__(
-        self, accelerator: Optional["pl.Accelerator"] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None
+        self,
+        accelerator: Optional["pl.Accelerator"] = None,
+        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
     ) -> None:
         self._accelerator = accelerator
         self._model: Optional[Module] = None
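These signature hunks all apply the same convention: one keyword argument per line, closed with a trailing comma. Formatters such as black treat that trailing ("magic") comma as a request to keep the call or signature exploded. A minimal illustration with a hypothetical function, not Lightning code:

from typing import Optional

def configure(
    accelerator: Optional[str] = None,
    checkpoint_io: Optional[str] = None,
    precision_plugin: Optional[str] = None,  # trailing comma keeps one-per-line formatting
) -> tuple:
    return (accelerator, checkpoint_io, precision_plugin)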

pytorch_lightning/trainer/connectors/accelerator_connector.py (15 additions, 4 deletions)

@@ -691,10 +691,16 @@ def select_precision_plugin(self) -> PrecisionPlugin:
 
     def create_training_type_plugin(self) -> TrainingTypePlugin:
         if self.use_ddp2:
-            plugin = DDP2Plugin(accelerator=self.accelerator, parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment,)
+            plugin = DDP2Plugin(
+                accelerator=self.accelerator,
+                parallel_devices=self.parallel_devices,
+                cluster_environment=self.cluster_environment,
+            )
         elif self.use_ddp and self.use_deepspeed:
             plugin = DeepSpeedPlugin(
-                accelerator=self.accelerator, cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices,
+                accelerator=self.accelerator,
+                cluster_environment=self.select_cluster_environment(),
+                parallel_devices=self.parallel_devices,
             )
         elif self.use_ddp:
             use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks()
@@ -733,7 +739,9 @@ def create_training_type_plugin(self) -> TrainingTypePlugin:
                 ddp_plugin_cls = DDPPlugin
 
             plugin = ddp_plugin_cls(
-                accelerator=self.accelerator, parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment,
+                accelerator=self.accelerator,
+                parallel_devices=self.parallel_devices,
+                cluster_environment=self.cluster_environment,
             )
         elif self.use_dp:
             plugin = DataParallelPlugin(accelerator=self.accelerator, parallel_devices=self.parallel_devices)
@@ -745,7 +753,10 @@ def create_training_type_plugin(self) -> TrainingTypePlugin:
             plugin = IPUPlugin(accelerator=self.accelerator, parallel_devices=self.parallel_devices)
         else:
             single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
-            plugin = SingleDevicePlugin(accelerator=self.accelerator, device=(torch.device(f"cuda:{single_gpu_ordinal}" if self.use_gpu else "cpu")),)
+            plugin = SingleDevicePlugin(
+                accelerator=self.accelerator,
+                device=(torch.device(f"cuda:{single_gpu_ordinal}" if self.use_gpu else "cpu")),
+            )
         return plugin
 
     def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
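The final branch picks the root device at plugin-construction time. A stripped-down sketch of that device-selection idiom (illustrative function, not the Lightning API):

import torch

def root_device(use_gpu: bool, gpu_ordinal: int = 0) -> torch.device:
    # Mirrors the SingleDevicePlugin branch: a specific CUDA ordinal when a
    # GPU is requested, otherwise fall back to CPU.
    return torch.device(f"cuda:{gpu_ordinal}" if use_gpu else "cpu")

print(root_device(False))    # cpu
print(root_device(True, 1))  # cuda:1 (constructing a torch.device needs no GPU)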

tests/accelerators/test_cpu.py (9 additions, 2 deletions)

@@ -16,7 +16,9 @@
 
 def test_restore_checkpoint_after_pre_dispatch_default():
     """Assert default for restore_checkpoint_after_pre_dispatch is False."""
-    plugin = SingleDevicePlugin(accelerator=CPUAccelerator(), device=torch.device("cpu"), precision_plugin=PrecisionPlugin())
+    plugin = SingleDevicePlugin(
+        accelerator=CPUAccelerator(), device=torch.device("cpu"), precision_plugin=PrecisionPlugin()
+    )
     assert not plugin.restore_checkpoint_after_pre_dispatch
     assert not plugin.restore_checkpoint_after_pre_dispatch
 
@@ -48,7 +50,12 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
     checkpoint_path = os.path.join(tmpdir, "model.pt")
     trainer.save_checkpoint(checkpoint_path)
 
-    plugin = TestPlugin(accelerator=CPUAccelerator(), precision_plugin=PrecisionPlugin(), device=torch.device("cpu"), checkpoint_io=TorchCheckpointIO())
+    plugin = TestPlugin(
+        accelerator=CPUAccelerator(),
+        precision_plugin=PrecisionPlugin(),
+        device=torch.device("cpu"),
+        checkpoint_io=TorchCheckpointIO(),
+    )
     assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
 
     trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, fast_dev_run=True)
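The first test constructs the plugin directly and asserts the default of a boolean property. The same pattern in miniature, with a toy class rather than the real plugin:

class ToyPlugin:
    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        # Default is False, matching the assertion in the test above.
        return False

assert not ToyPlugin().restore_checkpoint_after_pre_dispatch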

tests/accelerators/test_tpu.py (1 addition, 1 deletion)

@@ -13,7 +13,7 @@
 # limitations under the License
 import collections
 from copy import deepcopy
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
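The only change here is import ordering: names within a from-import are sorted the way tools such as isort enforce by default (case-sensitive ASCII order). For example:

# A sorted from-import, as isort would emit it:
from unittest.mock import MagicMock, Mock, call, patch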
