This repository was archived by the owner on Sep 28, 2022. It is now read-only.

Commit b699108

awaelchli authored and Raalsky committed
Improve code quality in AcceleratorConnector._configure_slurm_ddp (Lightning-AI#10102)
1 parent e0bd1e2 commit b699108

File tree

4 files changed: 36 additions & 40 deletions

pytorch_lightning/plugins/environments/slurm_environment.py

Lines changed: 5 additions & 0 deletions
@@ -28,6 +28,11 @@ class SLURMEnvironment(ClusterEnvironment):
     def creates_processes_externally(self) -> bool:
         return True

+    @staticmethod
+    def detect() -> bool:
+        """Returns ``True`` if the current process was launched on a SLURM cluster."""
+        return "SLURM_NTASKS" in os.environ
+
     @property
     def main_address(self) -> str:
         # figure out the root node addr
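
Note: the new ``detect()`` helper only checks for the ``SLURM_NTASKS`` variable that SLURM exports to every task. A minimal sketch of how it could be exercised, assuming the class is importable from ``pytorch_lightning.plugins.environments`` as the file path suggests (the environment values below are illustrative, not part of this commit):

    import os
    from pytorch_lightning.plugins.environments import SLURMEnvironment

    # outside SLURM there is no SLURM_NTASKS, so detection reports False
    os.environ.pop("SLURM_NTASKS", None)
    assert SLURMEnvironment.detect() is False

    # under sbatch/srun, SLURM exports SLURM_NTASKS for the allocation
    os.environ["SLURM_NTASKS"] = "4"  # illustrative value
    assert SLURMEnvironment.detect() is True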

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 25 additions & 33 deletions
@@ -134,7 +134,6 @@ def __init__(
         self.precision = precision
         self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
         self.amp_level = amp_level
-        self._is_slurm_managing_tasks = False

         self._precision_plugin: Optional[PrecisionPlugin] = None
         self._training_type_plugin: Optional[TrainingTypePlugin] = None
@@ -167,7 +166,6 @@ def __init__(
         self.handle_given_plugins()
         self._set_distrib_type_if_training_type_plugin_passed()

-        self._configure_slurm_ddp()
         self._cluster_environment = self.select_cluster_environment()

         self.update_device_type_if_ipu_plugin()
@@ -703,15 +701,15 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
                 cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
             )
         elif self.use_ddp:
-            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks
+            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks()
             use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
             use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
             use_ddp_spawn = self._distrib_type == _StrategyType.DDP_SPAWN
             use_ddp_cpu_spawn = use_ddp_spawn and self.use_cpu
             use_tpu_spawn = self.use_tpu and self._distrib_type == _StrategyType.TPU_SPAWN
             use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
             use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
-            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks
+            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks()
             use_ddp_sharded = self._distrib_type == _StrategyType.DDP_SHARDED
             use_ddp_sharded_spawn = self._distrib_type == _StrategyType.DDP_SHARDED_SPAWN
             use_ddp_fully_sharded = self._distrib_type == _StrategyType.DDP_FULLY_SHARDED
@@ -807,8 +805,9 @@ def select_accelerator(self) -> Accelerator:
     def select_cluster_environment(self) -> ClusterEnvironment:
         if self._cluster_environment is not None:
             return self._cluster_environment
-        if self._is_slurm_managing_tasks:
+        if self._is_slurm_managing_tasks():
             env = SLURMEnvironment()
+            rank_zero_info("Multiprocessing is handled by SLURM.")
         elif TorchElasticEnvironment.is_using_torchelastic():
             env = TorchElasticEnvironment()
         elif KubeflowEnvironment.is_using_kubeflow():
@@ -990,34 +989,6 @@ def update_device_type_if_training_type_plugin_passed(self) -> None:
         elif self.has_gpu:
             self._device_type = DeviceType.GPU

-    def _configure_slurm_ddp(self):
-        # extract SLURM flag vars
-        # whenever we have the correct number of tasks, we let slurm manage processes
-        # otherwise we launch the required number of processes
-        if self.use_ddp or self.use_ddp2:
-            num_requested_gpus = self.num_gpus * self.num_nodes
-            num_slurm_tasks = 0
-            try:
-                num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
-                self._is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus
-
-                # enable slurm cpu
-                if num_requested_gpus == 0:
-                    self._is_slurm_managing_tasks = num_slurm_tasks == self.num_processes
-
-                # in interactive mode we don't manage tasks
-                job_name = os.environ["SLURM_JOB_NAME"]
-                if job_name == "bash":
-                    self._is_slurm_managing_tasks = False
-
-            except Exception:
-                # likely not on slurm, so set the slurm managed flag to false
-                self._is_slurm_managing_tasks = False
-
-            # notify user the that slurm is managing tasks
-            if self._is_slurm_managing_tasks:
-                rank_zero_info("Multi-processing is handled by Slurm.")
-
     def _set_distrib_type_if_training_type_plugin_passed(self):
         # This is required as when `TrainingTypePlugin` instance is passed to either `strategy`
         # or `plugins` flag, `AcceleratorConnector.set_distributed_mode` is not required to be
@@ -1026,3 +997,24 @@ def _set_distrib_type_if_training_type_plugin_passed(self):
             return
         if self._training_type_plugin is not None:
             self._distrib_type = getattr(self._training_type_plugin, "distributed_backend", None)
+
+    def _is_slurm_managing_tasks(self) -> bool:
+        """Returns whether we let SLURM manage the processes or not.
+
+        Returns ``True`` if and only if these conditions match:
+
+        - A SLURM cluster is detected
+        - A distributed plugin is being used
+        - The process is not launching in interactive mode
+        - The number of tasks in SLURM matches the requested number of devices and nodes in the Trainer
+        """
+        if (
+            (not self.use_ddp and not self.use_ddp2)
+            or not SLURMEnvironment.detect()
+            or os.environ.get("SLURM_JOB_NAME") == "bash"  # in interactive mode we don't manage tasks
+        ):
+            return False
+
+        total_requested_devices = (self.num_gpus or self.num_processes) * self.num_nodes
+        num_slurm_tasks = int(os.environ["SLURM_NTASKS"], 0)
+        return num_slurm_tasks == total_requested_devices
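
Note: the new ``_is_slurm_managing_tasks()`` turns the old mutable flag into a pure check over the environment and the requested devices. A standalone sketch of the same decision rule (the function and parameter names below are illustrative, not Lightning API):

    import os

    def slurm_is_managing_tasks(use_ddp: bool, use_ddp2: bool, num_gpus: int, num_processes: int, num_nodes: int) -> bool:
        # no distributed plugin requested or no SLURM cluster detected: SLURM does not manage the processes
        if not (use_ddp or use_ddp2) or "SLURM_NTASKS" not in os.environ:
            return False
        # interactive jobs get the job name "bash" and are never managed
        if os.environ.get("SLURM_JOB_NAME") == "bash":
            return False
        # SLURM must have been asked for exactly one task per requested device across all nodes
        total_requested_devices = (num_gpus or num_processes) * num_nodes
        return int(os.environ["SLURM_NTASKS"]) == total_requested_devices

For example, a ``Trainer(gpus=2, num_nodes=2)`` run submitted with ``#SBATCH --ntasks=4`` satisfies the last condition; with any other task count Lightning falls back to launching the required number of processes itself.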

tests/accelerators/test_accelerator_connector.py

Lines changed: 6 additions & 6 deletions
@@ -103,7 +103,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_accelerator_choice_ddp_slurm(set_device_mock, device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -136,7 +136,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -323,7 +323,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -791,7 +791,7 @@ def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -824,7 +824,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -1008,7 +1008,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock, strategy):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer._accelerator_connector._is_slurm_managing_tasks
+            assert trainer._accelerator_connector._is_slurm_managing_tasks()
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
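
Note: these tests exercise the SLURM code path without a real cluster. A rough sketch of how a SLURM environment can be faked for such a test (the variable names and values below are illustrative, not copied from the test file):

    import os
    from unittest import mock

    fake_slurm_env = {
        "SLURM_NTASKS": "2",            # should equal devices * nodes requested from the Trainer
        "SLURM_JOB_NAME": "SOME_NAME",  # anything except "bash", which marks an interactive job
        "SLURM_NODEID": "0",
        "SLURM_PROCID": "0",
        "SLURM_LOCALID": "0",
    }

    with mock.patch.dict(os.environ, fake_slurm_env):
        ...  # build the Trainer and call fit(); the CB callback above then runs the assertions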

tests/models/test_restore.py

Lines changed: 0 additions & 1 deletion
@@ -500,7 +500,6 @@ def test_dp_resume(tmpdir):

     # fit model
     trainer = Trainer(**trainer_options)
-    trainer._is_slurm_managing_tasks = True
     trainer.fit(model, datamodule=dm)

     # track epoch before saving. Increment since we finished the current epoch, don't want to rerun

0 commit comments
