Commit 76081fb

Mark SLURM detection methods in AcceleratorConnector as protected (#10101)
Co-authored-by: Justus Schock <[email protected]>
Parent: 2ee3127

5 files changed: +59 −18 lines

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -412,6 +412,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))
 
 
+- Deprecated access to the `AcceleratorConnector.is_slurm_managing_tasks` attribute and marked it as protected ([#10101](https://github.com/PyTorchLightning/pytorch-lightning/pull/10101))
+
+
+- Deprecated access to the `AcceleratorConnector.configure_slurm_ddp` method and marked it as protected ([#10101](https://github.com/PyTorchLightning/pytorch-lightning/pull/10101))
+
+
 ### Removed
 
 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
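
The changelog entries above describe a rename-with-deprecation shim: the old public attribute and method keep working for one more release, but they forward to protected names and emit a deprecation warning. Below is a minimal standalone sketch of that pattern using only the standard library (Lightning itself uses its `rank_zero_deprecation` helper, as shown in the diff further down); `AcceleratorConnectorLike` is a hypothetical stand-in, not the real class.

import warnings


class AcceleratorConnectorLike:
    """Deprecation shim: the old public name still works but warns and forwards."""

    def __init__(self) -> None:
        self._is_slurm_managing_tasks = False  # new, protected attribute

    @property
    def is_slurm_managing_tasks(self) -> bool:
        warnings.warn(
            "`is_slurm_managing_tasks` is deprecated and will be removed in a future release.",
            DeprecationWarning,
        )
        return self._is_slurm_managing_tasks

    @is_slurm_managing_tasks.setter
    def is_slurm_managing_tasks(self, value: bool) -> None:
        warnings.warn(
            "`is_slurm_managing_tasks` is deprecated and will be removed in a future release.",
            DeprecationWarning,
        )
        self._is_slurm_managing_tasks = value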

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 31 additions & 11 deletions
@@ -135,7 +135,7 @@ def __init__(
         self.precision = precision
         self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
         self.amp_level = amp_level
-        self.is_slurm_managing_tasks = False
+        self._is_slurm_managing_tasks = False
 
         self._precision_plugin: Optional[PrecisionPlugin] = None
         self._training_type_plugin: Optional[TrainingTypePlugin] = None
@@ -164,7 +164,7 @@ def __init__(
             self._set_training_type_plugin()
         else:
             self.set_distributed_mode()
-            self.configure_slurm_ddp()
+            self._configure_slurm_ddp()
 
         self.handle_given_plugins()
         self.update_device_type_if_ipu_plugin()
@@ -685,15 +685,15 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
                 cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
             )
         elif self.use_ddp:
-            use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
+            use_slurm_ddp = self.use_ddp and self._is_slurm_managing_tasks
             use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
             use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
             use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
             use_ddp_cpu_spawn = use_ddp_spawn and self.use_cpu
             use_tpu_spawn = self.use_tpu and self._distrib_type == DistributedType.TPU_SPAWN
             use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
             use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
-            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
+            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self._is_slurm_managing_tasks
             use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
             use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
             use_ddp_fully_sharded = self._distrib_type == DistributedType.DDP_FULLY_SHARDED
@@ -789,7 +789,7 @@ def select_accelerator(self) -> Accelerator:
     def select_cluster_environment(self) -> ClusterEnvironment:
         if self._cluster_environment is not None:
             return self._cluster_environment
-        if self.is_slurm_managing_tasks:
+        if self._is_slurm_managing_tasks:
             env = SLURMEnvironment()
         elif TorchElasticEnvironment.is_using_torchelastic():
             env = TorchElasticEnvironment()
@@ -972,7 +972,27 @@ def update_device_type_if_training_type_plugin_passed(self) -> None:
         elif self.has_gpu:
             self._device_type = DeviceType.GPU
 
-    def configure_slurm_ddp(self):
+    @property
+    def is_slurm_managing_tasks(self) -> bool:
+        rank_zero_deprecation(
+            "`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5 and will be removed in v1.6."
+        )
+        return self._is_slurm_managing_tasks
+
+    @is_slurm_managing_tasks.setter
+    def is_slurm_managing_tasks(self, value: bool) -> bool:
+        rank_zero_deprecation(
+            "`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5 and will be removed in v1.6."
+        )
+        self._is_slurm_managing_tasks = value
+
+    def configure_slurm_ddp(self) -> None:
+        rank_zero_deprecation(
+            "`AcceleratorConnector.configure_slurm_ddp()` was deprecated in v1.5 and will be removed in v1.6."
+        )
+        self._configure_slurm_ddp()
+
+    def _configure_slurm_ddp(self):
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
         # otherwise we launch the required number of processes
@@ -981,21 +1001,21 @@ def configure_slurm_ddp(self):
         num_slurm_tasks = 0
         try:
             num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
-            self.is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus
+            self._is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus
 
             # enable slurm cpu
             if num_requested_gpus == 0:
-                self.is_slurm_managing_tasks = num_slurm_tasks == self.num_processes
+                self._is_slurm_managing_tasks = num_slurm_tasks == self.num_processes
 
             # in interactive mode we don't manage tasks
             job_name = os.environ["SLURM_JOB_NAME"]
             if job_name == "bash":
-                self.is_slurm_managing_tasks = False
+                self._is_slurm_managing_tasks = False
 
         except Exception:
             # likely not on slurm, so set the slurm managed flag to false
-            self.is_slurm_managing_tasks = False
+            self._is_slurm_managing_tasks = False
 
         # notify user the that slurm is managing tasks
-        if self.is_slurm_managing_tasks:
+        if self._is_slurm_managing_tasks:
             rank_zero_info("Multi-processing is handled by Slurm.")

tests/accelerators/test_accelerator_connector.py

Lines changed: 6 additions & 6 deletions
@@ -100,7 +100,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_accelerator_choice_ddp_slurm(setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer.accelerator_connector.is_slurm_managing_tasks
+            assert trainer.accelerator_connector._is_slurm_managing_tasks
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -132,7 +132,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer.accelerator_connector.is_slurm_managing_tasks
+            assert trainer.accelerator_connector._is_slurm_managing_tasks
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -307,7 +307,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer.accelerator_connector.is_slurm_managing_tasks
+            assert trainer.accelerator_connector._is_slurm_managing_tasks
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -756,7 +756,7 @@ def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock):
 def test_strategy_choice_ddp_slurm(setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer.accelerator_connector.is_slurm_managing_tasks
+            assert trainer.accelerator_connector._is_slurm_managing_tasks
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -788,7 +788,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_strategy_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer.accelerator_connector.is_slurm_managing_tasks
+            assert trainer.accelerator_connector._is_slurm_managing_tasks
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
@@ -963,7 +963,7 @@ def on_fit_start(self, trainer, pl_module):
 def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
-            assert trainer.accelerator_connector.is_slurm_managing_tasks
+            assert trainer.accelerator_connector._is_slurm_managing_tasks
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)

tests/deprecated_api/test_remove_1-6.py

Lines changed: 15 additions & 0 deletions
@@ -406,3 +406,18 @@ def test_v1_6_0_deprecated_accelerator_pass_through_functions():
 
     with pytest.deprecated_call(match="will be removed in v1.6"):
         accelerator.on_train_batch_start(batch=None, batch_idx=0)
+
+
+def test_v1_6_0_configure_slurm_ddp():
+    trainer = Trainer()
+    with pytest.deprecated_call(match=r"`AcceleratorConnector.configure_slurm_ddp\(\)` was deprecated in v1.5"):
+        trainer.accelerator_connector.configure_slurm_ddp()
+
+
+def test_v1_6_0_is_slurm_managing_tasks():
+    trainer = Trainer()
+    with pytest.deprecated_call(match=r"`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5"):
+        _ = trainer.accelerator_connector.is_slurm_managing_tasks
+
+    with pytest.deprecated_call(match=r"`AcceleratorConnector.is_slurm_managing_tasks` was deprecated in v1.5"):
+        trainer.accelerator_connector.is_slurm_managing_tasks = False
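
The new tests rely on `pytest.deprecated_call`, which fails unless the code under test emits a `DeprecationWarning` (or `PendingDeprecationWarning`) matching the given pattern. A minimal, self-contained example of the same testing pattern, with a hypothetical `configure_slurm_ddp_like` function standing in for the deprecated method:

import warnings

import pytest


def configure_slurm_ddp_like() -> None:
    # stand-in for a deprecated public method that forwards to a protected one
    warnings.warn("`configure_slurm_ddp()` was deprecated in v1.5", DeprecationWarning)


def test_deprecation_warning_is_emitted() -> None:
    with pytest.deprecated_call(match=r"deprecated in v1\.5"):
        configure_slurm_ddp_like()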

tests/models/test_restore.py

Lines changed: 1 addition & 1 deletion
@@ -502,7 +502,7 @@ def test_dp_resume(tmpdir):
 
     # fit model
     trainer = Trainer(**trainer_options)
-    trainer.is_slurm_managing_tasks = True
+    trainer._is_slurm_managing_tasks = True
    trainer.fit(model, datamodule=dm)
 
     # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
