@@ -134,7 +134,6 @@ def __init__(
         self.precision = precision
         self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
         self.amp_level = amp_level
-        self._is_slurm_managing_tasks = False
 
         self._precision_plugin: Optional[PrecisionPlugin] = None
         self._training_type_plugin: Optional[TrainingTypePlugin] = None
@@ -167,7 +166,6 @@ def __init__(
         self.handle_given_plugins()
         self._set_distrib_type_if_training_type_plugin_passed()
 
-        self._configure_slurm_ddp()
         self._cluster_environment = self.select_cluster_environment()
 
         self.update_device_type_if_ipu_plugin()
@@ -703,15 +701,15 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
                 cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
             )
         elif self.use_ddp:
-            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks
+            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks()
             use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
             use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
             use_ddp_spawn = self._distrib_type == _StrategyType.DDP_SPAWN
             use_ddp_cpu_spawn = use_ddp_spawn and self.use_cpu
             use_tpu_spawn = self.use_tpu and self._distrib_type == _StrategyType.TPU_SPAWN
             use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
             use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
-            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks
+            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks()
             use_ddp_sharded = self._distrib_type == _StrategyType.DDP_SHARDED
             use_ddp_sharded_spawn = self._distrib_type == _StrategyType.DDP_SHARDED_SPAWN
             use_ddp_fully_sharded = self._distrib_type == _StrategyType.DDP_FULLY_SHARDED
@@ -807,8 +805,9 @@ def select_accelerator(self) -> Accelerator:
     def select_cluster_environment(self) -> ClusterEnvironment:
         if self._cluster_environment is not None:
             return self._cluster_environment
-        if self._is_slurm_managing_tasks:
+        if self._is_slurm_managing_tasks():
             env = SLURMEnvironment()
+            rank_zero_info("Multiprocessing is handled by SLURM.")
         elif TorchElasticEnvironment.is_using_torchelastic():
             env = TorchElasticEnvironment()
         elif KubeflowEnvironment.is_using_kubeflow():
@@ -990,34 +989,6 @@ def update_device_type_if_training_type_plugin_passed(self) -> None:
         elif self.has_gpu:
             self._device_type = DeviceType.GPU
 
-    def _configure_slurm_ddp(self):
-        # extract SLURM flag vars
-        # whenever we have the correct number of tasks, we let slurm manage processes
-        # otherwise we launch the required number of processes
-        if self.use_ddp or self.use_ddp2:
-            num_requested_gpus = self.num_gpus * self.num_nodes
-            num_slurm_tasks = 0
-            try:
-                num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
-                self._is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus
-
-                # enable slurm cpu
-                if num_requested_gpus == 0:
-                    self._is_slurm_managing_tasks = num_slurm_tasks == self.num_processes
-
-                # in interactive mode we don't manage tasks
-                job_name = os.environ["SLURM_JOB_NAME"]
-                if job_name == "bash":
-                    self._is_slurm_managing_tasks = False
-
-            except Exception:
-                # likely not on slurm, so set the slurm managed flag to false
-                self._is_slurm_managing_tasks = False
-
-        # notify user the that slurm is managing tasks
-        if self._is_slurm_managing_tasks:
-            rank_zero_info("Multi-processing is handled by Slurm.")
-
     def _set_distrib_type_if_training_type_plugin_passed(self):
         # This is required as when `TrainingTypePlugin` instance is passed to either `strategy`
         # or `plugins` flag, `AcceleratorConnector.set_distributed_mode` is not required to be
@@ -1026,3 +997,24 @@ def _set_distrib_type_if_training_type_plugin_passed(self):
             return
         if self._training_type_plugin is not None:
             self._distrib_type = getattr(self._training_type_plugin, "distributed_backend", None)
+
+    def _is_slurm_managing_tasks(self) -> bool:
+        """Returns whether we let SLURM manage the processes or not.
+
+        Returns ``True`` if and only if these conditions match:
+
+            - A SLURM cluster is detected
+            - A distributed plugin is being used
+            - The process is not launching in interactive mode
+            - The number of tasks in SLURM matches the requested number of devices and nodes in the Trainer
+        """
+        if (
+            (not self.use_ddp and not self.use_ddp2)
+            or not SLURMEnvironment.detect()
+            or os.environ.get("SLURM_JOB_NAME") == "bash"  # in interactive mode we don't manage tasks
+        ):
+            return False
+
+        total_requested_devices = (self.num_gpus or self.num_processes) * self.num_nodes
+        num_slurm_tasks = int(os.environ["SLURM_NTASKS"], 0)
+        return num_slurm_tasks == total_requested_devices
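
For reference, the consolidated check added above can be exercised in isolation. The sketch below is not part of this diff: it approximates the same decision with plain environment lookups (the presence of SLURM_NTASKS stands in for SLURMEnvironment.detect()), and requested_devices/using_ddp are hypothetical stand-ins for the Trainer state the real method reads.

import os

def slurm_manages_tasks(requested_devices: int, using_ddp: bool) -> bool:
    # Rough stand-in for SLURMEnvironment.detect(): a SLURM allocation exports SLURM_NTASKS.
    on_slurm = "SLURM_NTASKS" in os.environ
    # Interactive srun/salloc shells report the job name "bash"; those are not managed.
    interactive = os.environ.get("SLURM_JOB_NAME") == "bash"
    if not using_ddp or not on_slurm or interactive:
        return False
    # Let SLURM manage the processes only when it allocated exactly one task per requested device.
    return int(os.environ["SLURM_NTASKS"]) == requested_devices

# Example: a 2-node job requesting 4 GPUs per node is managed by SLURM only if it was
# launched with 8 tasks, e.g. slurm_manages_tasks(requested_devices=8, using_ddp=True).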