@@ -171,8 +171,6 @@ def __init__(
        # Reset strategy even user has specificed one
        self._strategy_check_and_fallbacks()
        self._init_strategy()
-        if _HOROVOD_AVAILABLE and isinstance(self.strategy, HorovodStrategy):
-            self.handle_horovod

        # --Precision----------------------------------------------------------------
        self.precision_plugin = self._check_capatibility_and_init_precision()
@@ -411,7 +409,7 @@ def _choose_accelerator(self):
        #     self._existing_accelerator_type, [_TPU_AVAILABLE, _IPU_AVAILABLE, torch.cuda.is_available(), True]
        # ):
        #     # only apply to gpu to keep backward compatibility
-        #     if self._accelerator_flag == accelerator_flag == "gpu" :
+        #     if self._accelerator_flag == accelerator_flag:
        #         if not available:
        #             raise MisconfigurationException(
        #                 f"You choice {accelerator_flag} accelerator, but {accelerator_flag} is not available"
@@ -491,9 +489,10 @@ def _choose_strategy(self):
            if self._parallel_devices and len(self._parallel_devices) > 1:
                self._strategy_flag = "tpu_spawn"
            else:
-                self._srategy_flag = SingleTPUStrategy(device=self._parallel_devices[0])
+                # TODO lazy initialized device, then here could be self._strategy_flag = "single_tpu_device"
+                self._strategy_flag = SingleTPUStrategy(device=self._parallel_devices[0])
        elif _HOROVOD_AVAILABLE and ("OMPI_COMM_WORLD_RANK" in os.environ or "HOROVOD_RANK" in os.environ):
-            self._strategy_flag = HorovodStrategy()
+            self._strategy_flag = "horovod"
        else:
            if self._num_nodes_flag > 1:
                self._strategy_flag = "ddp"
@@ -503,14 +502,15 @@ def _choose_strategy(self):
                    if self._accelerator_flag == "gpu"
                    else "cpu"
                )
+                # TODO lazy initialized device, then here could be self._strategy_flag = "single_device"
                self._strategy_flag = SingleDeviceStrategy(device=device)
            elif len(self._parallel_devices) > 1:
                self._strategy_flag = "ddp_spawn"
            else:
                self._strategy_flag = "ddp"

    def _strategy_check_and_fallbacks(self):
-        # fallback apply to user pass in object as well, so get the _strategy_flag first
+        # current logic, fallback only apply to user pass in str config not object config
        _strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag

        if _strategy_flag == "ddp_cpu":
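A minimal, self-contained sketch of the guard introduced above: when the user passes a Strategy object, the local flag is blanked so none of the string-based fallbacks rewrite it. `StrategyStub` and the simplified `ddp_cpu` mapping are illustrative stand-ins, not the library's actual classes or full rules.

class StrategyStub:  # stand-in for pytorch_lightning's Strategy base class
    pass

def check_and_fallback(strategy_flag, accelerator_flag, num_processes=1):
    # object configs bypass the string fallbacks entirely
    flag = "" if isinstance(strategy_flag, StrategyStub) else strategy_flag
    if flag == "ddp_cpu":
        flag = "ddp_spawn" if num_processes > 1 else "ddp"  # simplified mapping
    if flag in ("dp", "ddp2") and accelerator_flag == "cpu":
        flag = "ddp"
    return flag if flag else strategy_flag  # keep the original object/string

assert check_and_fallback("dp", "cpu") == "ddp"
obj = StrategyStub()
assert check_and_fallback(obj, "cpu") is obj  # object config left untouched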
@@ -534,33 +534,18 @@ def _strategy_check_and_fallbacks(self):
        if _strategy_flag in ("dp", "ddp2") and self._accelerator_flag == "cpu":
            rank_zero_warn(f"{_strategy_flag!r} is not supported on CPUs, hence setting `strategy='ddp'`.")
            _strategy_flag = "ddp"
-        # Current test check precision first. So move this test to the end for now.
-        # TODO update tests and uncomment this part
-        # if isinstance(self.accelerator, TPUAccelerator) and "tpu" not in _strategy_flag:
-        #     raise ValueError(
-        #         "The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy`,"
-        #         f" found {_strategy_flag}."
-        #     )

        if _strategy_flag:
            self._strategy_flag = _strategy_flag

-    def _init_strategy(self):
-        if isinstance(self._strategy_flag, str):
-            self.strategy = StrategyRegistry.get(self._strategy_flag)
-        elif isinstance(self._strategy_flag, Strategy):
-            self.strategy = self._strategy_flag
-        else:
-            raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}")
-
    def handle_horovod(self):
        if self._num_nodes_flag > 1:
            raise MisconfigurationException(
                "Horovod does not support setting num_nodes / num_gpus explicitly. Use "
                "horovodrun / mpirun to configure the number of processes."
            )

-        if isinstance(self.strategy, HorovodStrategy) and not _HOROVOD_AVAILABLE:
+        if not _HOROVOD_AVAILABLE:
            raise MisconfigurationException(
                'Requested `accelerator="horovod"`, but Horovod is not installed.'
                "Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]"
@@ -573,6 +558,18 @@ def handle_horovod(self):
        else:
            self._parallel_device = hvd.local_size()

+    def _init_strategy(self):
+        if isinstance(self._strategy_flag, HorovodStrategy) or self._strategy_flag == "horovod":
+            # handle horovod has to happen before initialize strategy because HorovodStrategy needs hvd.init() first.
+            # TODO lazy initialized and setup horovod strategy `global_rank`
+            self.handle_horovod()
+        if isinstance(self._strategy_flag, str):
+            self.strategy = StrategyRegistry.get(self._strategy_flag)
+        elif isinstance(self._strategy_flag, Strategy):
+            self.strategy = self._strategy_flag
+        else:
+            raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}")
+
    def _check_capatibility_and_init_precision(self):
        self._precision_misconfig_check()
        if isinstance(self._precision_flag, PrecisionPlugin):
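A minimal sketch of the dispatch pattern `_init_strategy` now follows: resolve string flags through a registry, accept strategy instances as-is, and run the Horovod setup first because `hvd.init()` must precede instantiation. The registry and strategy classes below are simplified stand-ins for `StrategyRegistry` and the real strategy types, not the library's API.

class Strategy:  # simplified stand-in
    pass

class HorovodStrategy(Strategy):
    pass

STRATEGY_REGISTRY = {"ddp": Strategy, "horovod": HorovodStrategy}

def init_strategy(strategy_flag):
    if isinstance(strategy_flag, HorovodStrategy) or strategy_flag == "horovod":
        pass  # hvd.init() / environment checks would run here, before instantiation
    if isinstance(strategy_flag, str):
        return STRATEGY_REGISTRY[strategy_flag]()  # string flag -> registered class
    if isinstance(strategy_flag, Strategy):
        return strategy_flag  # user-supplied instance is used directly
    raise RuntimeError(f"{strategy_flag!r} is not a valid strategy type")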
@@ -699,6 +696,8 @@ def _lazy_init_strategy(self):
                " creation inside the worker function."
            )

+        # TODO should be moved to _strategy_check_and_fallbacks().
+        # Current test check precision first, so keep this check here to meet error order
        if isinstance(self.accelerator, TPUAccelerator) and not isinstance(
            self.strategy, (SingleTPUStrategy, TPUSpawnStrategy)
        ):