Commit 7c51ab7

support customized ttp and accelerator
1 parent 2294b1c commit 7c51ab7
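
What "customized ttp and accelerator" means in practice: a user-defined training type plugin (ttp) can now be constructed with its own accelerator and precision plugin and passed to the Trainer via `strategy=`, and the accelerator connector keeps those objects instead of replacing them. A minimal sketch distilled from the updated test in tests/accelerators/test_accelerator_connector.py below; the subclass names come from that test, the import paths are assumed from this era of the codebase, and the real test runs under a mocked SLURM/DDP environment, so treat this as illustrative rather than a standalone script:

    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import Accelerator
    from pytorch_lightning.plugins import DDPPlugin, PrecisionPlugin

    class Accel(Accelerator):          # custom accelerator
        pass

    class Prec(PrecisionPlugin):       # custom precision plugin
        pass

    class TrainTypePlugin(DDPPlugin):  # custom training type plugin carrying its own components
        pass

    ttp = TrainTypePlugin(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec())
    trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)

    # the connector keeps the user-provided plugin and its attached components
    assert trainer.training_type_plugin is ttp
    assert isinstance(trainer.accelerator, Accel)
    assert isinstance(trainer.precision_plugin, Prec)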

File tree

5 files changed: +68 / -60 lines

pytorch_lightning/accelerators/gpu.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ def setup_environment(self, root_device: torch.device) -> None:
         torch.cuda.set_device(root_device)
 
     def setup(self, trainer: "pl.Trainer") -> None:
+        # TODO refactor input from trainer to local_rank @four4fish
         self.set_nvidia_flags(trainer.local_rank)
         # clear cache before training
         torch.cuda.empty_cache()

pytorch_lightning/accelerators/tpu.py

Lines changed: 0 additions & 24 deletions
@@ -15,12 +15,7 @@
 
 import torch
 
-# import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
-
-# from pytorch_lightning.plugins.precision import TPUPrecisionPlugin
-# from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
-# from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
@@ -30,25 +25,6 @@
 class TPUAccelerator(Accelerator):
     """Accelerator for TPU devices."""
 
-    # def setup(self, trainer: "pl.Trainer") -> None:
-    #     """
-    #     Raises:
-    #         ValueError:
-    #             If the precision or training type plugin are unsupported.
-    #     """
-    #     if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin):
-    #         # this configuration should have been avoided in the accelerator connector
-    #         raise ValueError(
-    #             f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
-    #             f" found: {self.training_type_plugin.precision_plugin}."
-    #         )
-    #     if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
-    #         raise ValueError(
-    #             "The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin,"
-    #             f" found {self.training_type_plugin}."
-    #         )
-    #     return super().setup(trainer)
-
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """Gets stats for the given TPU device.
 

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 29 additions & 5 deletions
@@ -178,6 +178,8 @@ def __init__(
         self.training_type_plugin = self.final_training_type_plugin()
         self.accelerator = self.training_type_plugin.accelerator
 
+        self._check_tpu_mis_config()
+
         # benchmarking
         # TODO: should this be moved to GPU accelerator?
         torch.backends.cudnn.benchmark = self.benchmark
@@ -405,12 +407,19 @@ def final_training_type_plugin(self) -> TrainingTypePlugin:
         # attach checkpoint plugin to the training type plugin
         if self._checkpoint_io is not None:
             self._training_type_plugin.checkpoint_io = self._checkpoint_io
-        precision_plugin = self.precision_plugin
-        if precision_plugin is not None:
-            self._training_type_plugin._precision_plugin = precision_plugin
+        if (
+            (hasattr(self.strategy, "precision_plugin") and self.precision_plugin is None)
+            or not hasattr(self.strategy, "precision_plugin")
+        ):
+            precision_plugin = self.precision_plugin
+            if precision_plugin is not None:
+                self._training_type_plugin._precision_plugin = precision_plugin
         self._training_type_plugin_resolved = True
-
-        self._training_type_plugin.accelerator = self.select_accelerator()
+        if (
+            (hasattr(self.strategy, "accelerator") and self.strategy.accelerator is None)
+            or not hasattr(self.strategy, "accelerator")
+        ):
+            self._training_type_plugin.accelerator = self.select_accelerator()
         return self._training_type_plugin
 
     @property
@@ -1016,3 +1025,18 @@ def _is_slurm_managing_tasks(self) -> bool:
         total_requested_devices = (self.num_gpus or self.num_processes) * self.num_nodes
         num_slurm_tasks = int(os.environ["SLURM_NTASKS"], 0)
         return num_slurm_tasks == total_requested_devices
+
+    def _check_tpu_mis_config(self) -> None:
+        # TODO moved from TPUAccelerator when refactor accelerator. Revisit when refactor
+        # accelerator_connector @four4fish
+        if isinstance(self.accelerator, TPUAccelerator):
+            if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin):
+                raise ValueError(
+                    f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
+                    f" found: {self.training_type_plugin.precision_plugin}."
+                )
+            if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
+                raise ValueError(
+                    "The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin,"
+                    f" found {self.training_type_plugin}."
+                )
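
With the check relocated here, an unsupported TPU combination now fails as soon as the Trainer is constructed rather than inside TPUAccelerator.setup. A hedged sketch of the observable behaviour, mirroring the updated tests/accelerators/test_tpu.py further down; the import paths are assumed from this era of the codebase (TPUPrecisionPlugin's module path is taken from the comments removed in tpu.py above):

    import pytest
    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import TPUAccelerator
    from pytorch_lightning.plugins import DDPPlugin
    from pytorch_lightning.plugins.precision import TPUPrecisionPlugin

    # a TPUAccelerator attached to a non-TPU training type plugin is rejected at construction time
    ttp = DDPPlugin(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin())
    with pytest.raises(ValueError, match="can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin"):
        Trainer(strategy=ttp)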

tests/accelerators/test_accelerator_connector.py

Lines changed: 28 additions & 20 deletions
@@ -397,7 +397,14 @@ def creates_processes_externally(self) -> bool:
 
 @mock.patch.dict(
     os.environ,
-    {"SLURM_NTASKS": "2", "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", "SLURM_LOCALID": "0"},
+    {
+        "SLURM_NTASKS": "2",
+        "SLURM_JOB_NAME": "SOME_NAME",
+        "SLURM_NODEID": "0",
+        "LOCAL_RANK": "0",
+        "SLURM_PROCID": "0",
+        "SLURM_LOCALID": "0",
+    },
 )
 @mock.patch("torch.cuda.device_count", return_value=0)
 @mock.patch("pytorch_lightning.plugins.DDPPlugin.setup_distributed", autospec=True)
@@ -408,28 +415,29 @@ class Accel(Accelerator):
     class Prec(PrecisionPlugin):
         pass
 
-    class TrainTypePlugin(SingleDevicePlugin):
+    class TrainTypePlugin(DDPPlugin):
         pass
 
+    ttp = TrainTypePlugin(
+        device=torch.device("cpu"),
+        accelerator=Accel(),
+        precision_plugin=Prec()
+    )
+    trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)
+    assert isinstance(trainer.accelerator, Accel)
+    assert isinstance(trainer.training_type_plugin, TrainTypePlugin)
+    assert isinstance(trainer.precision_plugin, Prec)
+    assert trainer._accelerator_connector.training_type_plugin is ttp
+
+    class DistributedPlugin(DDPPlugin):
+        pass
 
-    # ttp = TrainTypePlugin(device=torch.device("cpu"))
-    # accelerator = Accel(training_type_plugin=ttp, precision_plugin=Prec())
-    # trainer = Trainer(accelerator=accelerator, fast_dev_run=True, num_processes=2)
-    # assert isinstance(trainer.accelerator, Accel)
-    # assert isinstance(trainer.training_type_plugin, TrainTypePlugin)
-    # assert isinstance(trainer.precision_plugin, Prec)
-    # assert trainer._accelerator_connector.training_type_plugin is ttp
-
-    # class DistributedPlugin(DDPPlugin):
-    #     pass
-
-    # ttp = DistributedPlugin()
-    # accelerator = Accel(training_type_plugin=ttp, precision_plugin=Prec())
-    # trainer = Trainer(accelerator=accelerator, fast_dev_run=True, num_processes=2)
-    # assert isinstance(trainer.accelerator, Accel)
-    # assert isinstance(trainer.training_type_plugin, DistributedPlugin)
-    # assert isinstance(trainer.precision_plugin, Prec)
-    # assert trainer._accelerator_connector.training_type_plugin is ttp
+    ttp = DistributedPlugin(accelerator=Accel(), precision_plugin=Prec())
+    trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)
+    assert isinstance(trainer.accelerator, Accel)
+    assert isinstance(trainer.training_type_plugin, DistributedPlugin)
+    assert isinstance(trainer.precision_plugin, Prec)
+    assert trainer._accelerator_connector.training_type_plugin is ttp
 
 
 @mock.patch.dict(

tests/accelerators/test_tpu.py

Lines changed: 10 additions & 11 deletions
@@ -288,28 +288,27 @@ def forward(self, x):
 
 
 def test_tpu_invalid_raises():
-    # TODO move TPUAccelerator() and CPUAccelerator() setup() misconfig logic into strategies
     training_type_plugin = TPUSpawnPlugin(accelerator=TPUAccelerator(), precision_plugin=Mock())
-    # with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
-    #     training_type_plugin.setup(Mock())
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
+        Trainer(strategy=training_type_plugin)
 
     training_type_plugin = DDPPlugin(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin())
-    # with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`):
-    #     training_type_plugin.setup(Mock())
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"):
+        Trainer(strategy=training_type_plugin)
 
 
 def test_tpu_invalid_raises_set_precision_with_strategy():
     accelerator = TPUAccelerator()
     training_type_plugin = TPUSpawnPlugin(accelerator=accelerator, precision_plugin=object())
-    # with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"):
-    #     training_type_plugin.setup(object())
+    with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"):
+        Trainer(strategy=training_type_plugin)
 
     accelerator = TPUAccelerator()
     training_type_plugin = DDPPlugin(accelerator=accelerator, precision_plugin=TPUPrecisionPlugin())
-    # with pytest.raises(
-    #     ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin"
-    # ):
-    #     training_type_plugin.setup(object())
+    with pytest.raises(
+        ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin"
+    ):
+        Trainer(strategy=training_type_plugin)
 
 
 @RunIf(tpu=True)
