
Commit d43fd0d

four4fish, akihironitta, and carmocca authored

Lazy initialize Strategy.parallel_devices (#11572)

Co-authored-by: Aki Nitta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent eceefdc commit d43fd0d

File tree: 5 files changed, +19 −4 lines
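The commit makes `ParallelStrategy.parallel_devices` a property backed by `_parallel_devices`, and turns the DDP strategies' `num_processes` from an attribute frozen in `__init__` into a property derived on access. A minimal sketch of the pattern (the `Sketch` class below is illustrative, not Lightning code):

```python
from typing import List, Optional

import torch


class Sketch:
    """Toy model of the change: num_processes is derived, not stored."""

    def __init__(self, parallel_devices: Optional[List[torch.device]] = None) -> None:
        # Before this commit, the strategies froze the count here:
        #   self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
        self.parallel_devices = parallel_devices

    @property
    def num_processes(self) -> int:
        # After: recomputed on every access, so it tracks the current device list.
        return len(self.parallel_devices) if self.parallel_devices is not None else 0


s = Sketch()                                    # devices unknown at construction time
s.parallel_devices = [torch.device("cpu")] * 2  # attached later, e.g. by the connector
assert s.num_processes == 2                     # reflects the late-attached devices
```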

CHANGELOG.md
Lines changed: 3 additions & 0 deletions

```diff
@@ -238,6 +238,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Inherit from `ABC` for `Accelerator`: Users need to implement `auto_device_count` ([#11521](https://github.com/PyTorchLightning/pytorch-lightning/pull/11521))
 
 
+- Changed `parallel_devices` property in `ParallelStrategy` to be lazy initialized ([#11572](https://github.com/PyTorchLightning/pytorch-lightning/pull/11572))
+
+
 - Sorted `SimpleProfiler(extended=False)` summary based on mean duration for each hook ([#11671](https://github.com/PyTorchLightning/pytorch-lightning/pull/11671))
 
 
```

pytorch_lightning/strategies/ddp.py
Lines changed: 4 additions & 1 deletion

```diff
@@ -106,7 +106,6 @@ def __init__(
         self.interactive_ddp_procs = []
         self._num_nodes = 1
         self.sync_batchnorm = False
-        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
@@ -135,6 +134,10 @@ def num_nodes(self, num_nodes: int) -> None:
         self._num_nodes = num_nodes
         self.set_world_ranks()
 
+    @property
+    def num_processes(self):
+        return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
     @property
     def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
```
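Because `distributed_sampler_kwargs` computes `num_replicas` as `num_nodes * num_processes` (the last context line above), the sampler configuration now also picks up devices attached after construction. A hedged usage illustration, assuming `DDPStrategy` is importable from `pytorch_lightning.strategies` and accepts `parallel_devices` as a keyword argument:

```python
import torch
from pytorch_lightning.strategies import DDPStrategy  # import path assumed

strategy = DDPStrategy(parallel_devices=None)  # devices not yet resolved
assert strategy.num_processes == 0             # property falls back to 0 while unset

strategy.parallel_devices = [torch.device("cpu"), torch.device("cpu")]
assert strategy.num_processes == 2             # derived lazily from the device list
```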

pytorch_lightning/strategies/ddp_spawn.py
Lines changed: 4 additions & 1 deletion

```diff
@@ -82,7 +82,6 @@ def __init__(
         self._num_nodes = 1
         self.sync_batchnorm = False
         self._ddp_kwargs = kwargs
-        self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
@@ -107,6 +106,10 @@ def local_rank(self) -> int:
     def root_device(self):
         return self.parallel_devices[self.local_rank]
 
+    @property
+    def num_processes(self):
+        return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
     @property
     def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
```
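This is the same refactor as in `ddp.py`; note that the removed line here measured the `parallel_devices` constructor argument directly rather than `self.parallel_devices`, so the count went stale whenever devices were assigned after construction.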

pytorch_lightning/strategies/parallel.py
Lines changed: 8 additions & 0 deletions

```diff
@@ -72,6 +72,14 @@ def world_size(self) -> int:
     def is_global_zero(self) -> bool:
         return self.global_rank == 0
 
+    @property
+    def parallel_devices(self):
+        return self._parallel_devices
+
+    @parallel_devices.setter
+    def parallel_devices(self, parallel_devices):
+        self._parallel_devices = parallel_devices
+
     @property
     def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank)
```
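Backing the property with `_parallel_devices` also gives subclasses a hook to customize device resolution. A hypothetical sketch (the subclass below is not part of this commit) that resolves a default device list on first access:

```python
from typing import List, Optional

import torch


class LazyDevices:
    """Stand-in for the ParallelStrategy property/setter pair added above."""

    def __init__(self) -> None:
        self._parallel_devices: Optional[List[torch.device]] = None

    @property
    def parallel_devices(self) -> Optional[List[torch.device]]:
        return self._parallel_devices

    @parallel_devices.setter
    def parallel_devices(self, devices: Optional[List[torch.device]]) -> None:
        self._parallel_devices = devices


class DefaultsToCPU(LazyDevices):
    @property
    def parallel_devices(self) -> Optional[List[torch.device]]:
        # Hypothetical override: fall back to a default device list on first access.
        if self._parallel_devices is None:
            self._parallel_devices = [torch.device("cpu")]
        return self._parallel_devices

    @parallel_devices.setter
    def parallel_devices(self, devices: Optional[List[torch.device]]) -> None:
        self._parallel_devices = devices


assert DefaultsToCPU().parallel_devices == [torch.device("cpu")]
```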

pytorch_lightning/trainer/connectors/accelerator_connector.py
Lines changed: 0 additions & 2 deletions

```diff
@@ -759,8 -759,6 @@ def resolve_strategy(self, training_type: Strategy) -> Strategy:
         # necessary for when the user has passed in a plugin
         if hasattr(training_type, "parallel_devices") and getattr(training_type, "parallel_devices") is None:
             training_type.parallel_devices = self.parallel_devices
-        if hasattr(training_type, "num_processes"):
-            training_type.num_processes = len(self.parallel_devices)
 
         if hasattr(training_type, "cluster_environment") and getattr(training_type, "cluster_environment") is None:
             # transfer ownership of the cluster environment to the training type
```
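With the setter in place, assigning `training_type.parallel_devices` above is sufficient: `num_processes` is now derived from that list on each access, so the connector's explicit synchronization of `num_processes` became dead code and is deleted here.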
