@@ -485,16 +485,15 @@ def _is_slurm_managing_tasks(self):
485
485
return num_slurm_tasks == total_requested_devices
486
486
487
487
def _choose_strategy (self ):
488
- if _HOROVOD_AVAILABLE and ("OMPI_COMM_WORLD_RANK" in os .environ or "HOROVOD_RANK" in os .environ ):
489
- self ._strategy_flag = HorovodStrategy ()
490
-
491
488
if self ._accelerator_flag == "ipu" :
492
489
self ._strategy_flag = "ipu"
493
490
elif self ._accelerator_flag == "tpu" :
494
491
if self ._parallel_devices and len (self ._parallel_devices ) > 1 :
495
492
self ._strategy_flag = "tpu_spawn"
496
493
else :
497
494
self ._srategy_flag = SingleTPUStrategy (device = self ._parallel_devices [0 ])
495
+ elif _HOROVOD_AVAILABLE and ("OMPI_COMM_WORLD_RANK" in os .environ or "HOROVOD_RANK" in os .environ ):
496
+ self ._strategy_flag = HorovodStrategy ()
498
497
else :
499
498
if self ._num_nodes_flag > 1 :
500
499
self ._strategy_flag = "ddp"
@@ -549,8 +548,10 @@ def _strategy_check_and_fallbacks(self):
549
548
def _init_strategy (self ):
550
549
if isinstance (self ._strategy_flag , str ):
551
550
self .strategy = StrategyRegistry .get (self ._strategy_flag )
552
- else :
551
+ elif isinstance ( self . _strategy_flag , Strategy ) :
553
552
self .strategy = self ._strategy_flag
553
+ else :
554
+ raise RuntimeError (f"{ self .strategy } is not valid type: { self .strategy } " )
554
555
555
556
def handle_horovod (self ):
556
557
if self ._num_nodes_flag > 1 :
0 commit comments