
Commit 6e14209

Rewrite accelerator_connector (#11448)
1 parent a0ca8d0 commit 6e14209

32 files changed: +952 -1042 lines

pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -82,7 +82,6 @@ module = [
     "pytorch_lightning.profiler.pytorch",
     "pytorch_lightning.profiler.simple",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.connectors.accelerator_connector",
     "pytorch_lightning.trainer.connectors.callback_connector",
     "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.data_loading",

pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 1 addition & 2 deletions

@@ -29,7 +29,6 @@

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import _AcceleratorType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
@@ -127,7 +126,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
         if not trainer.logger:
             raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")

-        if trainer._device_type != _AcceleratorType.GPU:
+        if trainer.strategy.root_device.type != "cuda":
             raise MisconfigurationException(
                 "You are using GPUStatsMonitor but are not running on GPU"
                 f" since gpus attribute in Trainer is set to {trainer.gpus}."

pytorch_lightning/lite/lite.py

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ def __init__(
         self._check_strategy_support(strategy)
         gpu_ids, tpu_cores = _parse_devices(gpus=gpus, auto_select_gpus=False, tpu_cores=tpu_cores)
         self._accelerator_connector = AcceleratorConnector(
-            num_processes=1,
+            num_processes=None,
             devices=devices,
             tpu_cores=tpu_cores,
             ipus=None,

pytorch_lightning/strategies/bagua.py

Lines changed: 7 additions & 4 deletions

@@ -13,7 +13,6 @@
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.utilities.distributed import ReduceOp
-from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE
 from pytorch_lightning.utilities.seed import reset_seed
@@ -58,7 +57,7 @@ def __init__(self, pl_module: "pl.LightningModule") -> None:


 class BaguaStrategy(DDPStrategy):
-    distributed_backend = _StrategyType.BAGUA
+    strategy_name = "bagua"

     def __init__(
         self,
@@ -180,8 +179,12 @@ def _setup_model(self, model: Module) -> BaguaDistributedDataParallel:
         )

     @classmethod
-    def register_plugins(cls, plugin_registry: Dict) -> None:
-        plugin_registry.register("bagua", cls, description="Default Bagua Plugin")
+    def register_strategies(cls, strategy_registry: Dict) -> None:
+        strategy_registry.register(
+            cls.strategy_name,
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )

     def teardown(self) -> None:
         # abort the background communication for async algorithm
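
The register_plugins to register_strategies change above is the pattern every strategy file in this commit follows: declare a strategy_name class attribute and register the class under that name. A self-contained sketch of the idea, using a toy dict-backed registry rather than Lightning's own StrategyRegistry:

    from typing import Any, Dict, Type

    class ToyStrategyRegistry(dict):
        # Toy stand-in mimicking the register(name, cls, description=..., **init_params) call shape.
        def register(self, name: str, strategy: Type, description: str = "", **init_params: Any) -> None:
            self[name] = {"strategy": strategy, "description": description, "init_params": init_params}

    class ToyBaguaStrategy:
        strategy_name = "bagua"

        @classmethod
        def register_strategies(cls, strategy_registry: Dict) -> None:
            # Register under the canonical name so a trainer can later resolve strategy="bagua".
            strategy_registry.register(cls.strategy_name, cls, description=cls.__name__)

    registry = ToyStrategyRegistry()
    ToyBaguaStrategy.register_strategies(registry)
    assert "bagua" in registry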

pytorch_lightning/strategies/ddp.py

Lines changed: 6 additions & 4 deletions

@@ -45,7 +45,6 @@
 from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
-from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
 from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.seed import reset_seed
@@ -63,7 +62,7 @@
 class DDPStrategy(ParallelStrategy):
     """Strategy for multi-process single-device training on one or multiple nodes."""

-    distributed_backend = _StrategyType.DDP
+    strategy_name = "ddp"

     def __init__(
         self,
@@ -96,7 +95,6 @@ def __init__(
         self._pids: Optional[List[int]] = None
         self._sync_dir: Optional[str] = None
         self._rank_0_will_call_children_scripts: bool = False
-        self.set_world_ranks()

     @property
     def is_distributed(self) -> bool:
@@ -114,7 +112,6 @@ def num_nodes(self) -> int:
     def num_nodes(self, num_nodes: int) -> None:
         # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
         self._num_nodes = num_nodes
-        self.set_world_ranks()

     @property
     def num_processes(self):
@@ -346,6 +343,11 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
             description="DDP Strategy with `find_unused_parameters` as False",
             find_unused_parameters=False,
         )
+        strategy_registry.register(
+            cls.strategy_name,
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )

     def _should_run_deadlock_detection(self) -> bool:
         """Determines whether the plugin will perform process reconciliation in case of errors.

pytorch_lightning/strategies/ddp2.py

Lines changed: 11 additions & 2 deletions

@@ -11,18 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict
+
 import torch

 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION


 class DDP2Strategy(DDPStrategy):
     """DDP2 behaves like DP in one node, but synchronization across nodes behaves like in DDP."""

-    distributed_backend = _StrategyType.DDP2
+    strategy_name = "ddp2"

     @property
     def global_rank(self) -> int:
@@ -73,3 +74,11 @@ def set_world_ranks(self) -> None:
             return
         self.cluster_environment.set_global_rank(self.node_rank)
         self.cluster_environment.set_world_size(self.num_nodes)
+
+    @classmethod
+    def register_strategies(cls, strategy_registry: Dict) -> None:
+        strategy_registry.register(
+            cls.strategy_name,
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )

pytorch_lightning/strategies/ddp_spawn.py

Lines changed: 6 additions & 4 deletions

@@ -33,7 +33,6 @@
 from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
-from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -48,7 +47,7 @@ class DDPSpawnStrategy(ParallelStrategy):
     """Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
     finishes."""

-    distributed_backend = _StrategyType.DDP_SPAWN
+    strategy_name = "ddp_spawn"

     def __init__(
         self,
@@ -76,7 +75,6 @@ def __init__(
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
         self._local_rank = 0
-        self.set_world_ranks()

     @property
     def num_nodes(self) -> int:
@@ -86,7 +84,6 @@ def num_nodes(self) -> int:
     def num_nodes(self, num_nodes: int) -> None:
         # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
         self._num_nodes = num_nodes
-        self.set_world_ranks()

     @property
     def local_rank(self) -> int:
@@ -264,6 +261,11 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
             description="DDPSpawn Strategy with `find_unused_parameters` as False",
             find_unused_parameters=False,
         )
+        strategy_registry.register(
+            cls.strategy_name,
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )

     def teardown(self) -> None:
         super().teardown()

pytorch_lightning/strategies/deepspeed.py

Lines changed: 2 additions & 2 deletions

@@ -35,7 +35,7 @@
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log
-from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
+from pytorch_lightning.utilities.enums import AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -82,7 +82,7 @@ def _move_float_tensors_to_half(self, batch: Any):


 class DeepSpeedStrategy(DDPStrategy):
-    distributed_backend = _StrategyType.DEEPSPEED
+    strategy_name = "deepspeed"
     DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"

     def __init__(

pytorch_lightning/strategies/dp.py

Lines changed: 10 additions & 3 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional

 import torch
 from torch.nn import DataParallel, Module
@@ -22,7 +22,6 @@
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION, STEP_OUTPUT

@@ -31,7 +30,7 @@ class DataParallelStrategy(ParallelStrategy):
     """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and
     each gets a split of the data."""

-    distributed_backend = _StrategyType.DP
+    strategy_name = "dp"

     def __init__(
         self,
@@ -149,6 +148,14 @@ def training_step_end(self, output):

         return output

+    @classmethod
+    def register_strategies(cls, strategy_registry: Dict) -> None:
+        strategy_registry.register(
+            cls.strategy_name,
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
+
     def teardown(self) -> None:
         super().teardown()
         if self.root_device.type == "cuda":

pytorch_lightning/strategies/fully_sharded.py

Lines changed: 8 additions & 2 deletions

@@ -23,7 +23,7 @@
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
-from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -36,7 +36,7 @@

 class DDPFullyShardedStrategy(DDPStrategy):

-    distributed_backend = _StrategyType.DDP_FULLY_SHARDED
+    strategy_name = "ddp_fully_sharded"

     def __init__(
         self,
@@ -212,3 +212,9 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(
             "fsdp", cls, description="Fully sharded training with checkpointing the full state dict."
         )
+
+        strategy_registry.register(
+            cls.strategy_name,
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )

pytorch_lightning/strategies/horovod.py

Lines changed: 10 additions & 3 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import ExitStack
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -26,7 +26,6 @@
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as dist_group
 from pytorch_lightning.utilities.distributed import ReduceOp
-from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_only

@@ -37,7 +36,7 @@
 class HorovodStrategy(ParallelStrategy):
     """Plugin for Horovod distributed training integration."""

-    distributed_backend = _StrategyType.HOROVOD
+    strategy_name = "horovod"

     def __init__(
         self,
@@ -196,6 +195,14 @@ def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tup
         opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
         return [(name, p) for name, p in model.named_parameters() if p in opt_params]

+    @classmethod
+    def register_strategies(cls, strategy_registry: Dict) -> None:
+        strategy_registry.register(
+            cls.strategy_name,
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
+
     def teardown(self) -> None:
         super().teardown()
         # teardown may be called before `_exit_stack` is set

pytorch_lightning/strategies/ipu.py

Lines changed: 11 additions & 1 deletion

@@ -13,7 +13,7 @@
 # limitations under the License.
 import json
 import os
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from torch.utils.data import DataLoader
@@ -62,6 +62,8 @@ def _move_float_tensors_to_half(self, batch: Any) -> Any:
 class IPUStrategy(ParallelStrategy):
     """Plugin for training on IPU devices."""

+    strategy_name = "ipu_strategy"
+
     def __init__(
         self,
         accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
@@ -360,3 +362,11 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra

     def broadcast(self, obj: object, src: int = 0) -> object:
         return obj
+
+    @classmethod
+    def register_strategies(cls, strategy_registry: Dict) -> None:
+        strategy_registry.register(
+            cls.strategy_name,
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )

pytorch_lightning/strategies/sharded.py

Lines changed: 7 additions & 2 deletions

@@ -22,7 +22,7 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
@@ -37,7 +37,7 @@
 class DDPShardedStrategy(DDPStrategy):
     """Optimizer and gradient sharded training provided by FairScale."""

-    distributed_backend = _StrategyType.DDP_SHARDED
+    strategy_name = "ddp_sharded"
     _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23  # 8M

     def configure_ddp(self) -> None:
@@ -135,3 +135,8 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
             description="DDP Sharded Strategy with `find_unused_parameters` as False",
             find_unused_parameters=False,
         )
+        strategy_registry.register(
+            cls.strategy_name,
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
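
Taken together, these registrations make the string names ("ddp", "ddp_spawn", "dp", "deepspeed", "bagua", "horovod", "ddp_sharded", "ddp_fully_sharded", ...) resolvable by the rewritten accelerator connector. A minimal usage sketch, assuming a Lightning 1.6-era Trainer that accepts accelerator, devices, and strategy arguments:

    import pytorch_lightning as pl

    # The strategy string is looked up in the registry populated by the
    # register_strategies() classmethods added in this commit.
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")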
