Skip to content

Commit a0f80a6

Browse files
committed
debug tpu
1 parent e7c3bbf commit a0f80a6

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -485,16 +485,15 @@ def _is_slurm_managing_tasks(self):
485485
return num_slurm_tasks == total_requested_devices
486486

487487
def _choose_strategy(self):
488-
if _HOROVOD_AVAILABLE and ("OMPI_COMM_WORLD_RANK" in os.environ or "HOROVOD_RANK" in os.environ):
489-
self._strategy_flag = HorovodStrategy()
490-
491488
if self._accelerator_flag == "ipu":
492489
self._strategy_flag = "ipu"
493490
elif self._accelerator_flag == "tpu":
494491
if self._parallel_devices and len(self._parallel_devices) > 1:
495492
self._strategy_flag = "tpu_spawn"
496493
else:
497494
self._strategy_flag = SingleTPUStrategy(device=self._parallel_devices[0])
495+
elif _HOROVOD_AVAILABLE and ("OMPI_COMM_WORLD_RANK" in os.environ or "HOROVOD_RANK" in os.environ):
496+
self._strategy_flag = HorovodStrategy()
498497
else:
499498
if self._num_nodes_flag > 1:
500499
self._strategy_flag = "ddp"
@@ -549,8 +548,10 @@ def _strategy_check_and_fallbacks(self):
549548
def _init_strategy(self):
550549
if isinstance(self._strategy_flag, str):
551550
self.strategy = StrategyRegistry.get(self._strategy_flag)
552-
else:
551+
elif isinstance(self._strategy_flag, Strategy):
553552
self.strategy = self._strategy_flag
553+
else:
554+
raise RuntimeError(f"{self._strategy_flag} is not valid type: {type(self._strategy_flag)}")
554555

555556
def handle_horovod(self):
556557
if self._num_nodes_flag > 1:

0 commit comments

Comments
 (0)