Skip to content

Commit 254fb85

Browse files
committed
Rewrite accelerator_connector
1 parent 5b59c95 commit 254fb85

File tree

8 files changed

+482
-917
lines changed

8 files changed

+482
-917
lines changed

pytorch_lightning/strategies/ddp.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class DDPStrategy(ParallelStrategy):
8080
devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes.
8181
"""
8282

83-
distributed_backend = _StrategyType.DDP
83+
distributed_backend = "ddp"
8484

8585
def __init__(
8686
self,
@@ -431,6 +431,11 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
431431
description="DDP Strategy with `find_unused_parameters` as False",
432432
find_unused_parameters=False,
433433
)
434+
strategy_registry.register(
435+
cls.distributed_backend,
436+
cls,
437+
description="Strategy",
438+
)
434439

435440
def _should_run_deadlock_detection(self) -> bool:
436441
"""Determines whether the plugin will perform process reconciliation in case of errors.

pytorch_lightning/strategies/ddp2.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import torch
15+
from typing import Dict
1516

1617
from pytorch_lightning.strategies.ddp import DDPStrategy
1718
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -22,7 +23,7 @@
2223
class DDP2Strategy(DDPStrategy):
2324
"""DDP2 behaves like DP in one node, but synchronization across nodes behaves like in DDP."""
2425

25-
distributed_backend = _StrategyType.DDP2
26+
distributed_backend = "ddp2"
2627

2728
@property
2829
def global_rank(self) -> int:
@@ -73,3 +74,11 @@ def set_world_ranks(self) -> None:
7374
return
7475
self.cluster_environment.set_global_rank(self.node_rank)
7576
self.cluster_environment.set_world_size(self.num_nodes)
77+
78+
@classmethod
79+
def register_strategies(cls, strategy_registry: Dict) -> None:
80+
strategy_registry.register(
81+
cls.distributed_backend,
82+
cls,
83+
description="Strategy",
84+
)

pytorch_lightning/strategies/ddp_spawn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class DDPSpawnStrategy(ParallelStrategy):
5858
"""Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
5959
finishes."""
6060

61-
distributed_backend = _StrategyType.DDP_SPAWN
61+
distributed_backend = "ddp_spawn"
6262

6363
def __init__(
6464
self,
@@ -369,6 +369,11 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
369369
description="DDPSpawn Strategy with `find_unused_parameters` as False",
370370
find_unused_parameters=False,
371371
)
372+
strategy_registry.register(
373+
cls.distributed_backend,
374+
cls,
375+
description="Strategy",
376+
)
372377

373378
def teardown(self) -> None:
374379
super().teardown()

pytorch_lightning/strategies/deepspeed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _move_float_tensors_to_half(self, batch: Any):
8282

8383

8484
class DeepSpeedStrategy(DDPStrategy):
85-
distributed_backend = _StrategyType.DEEPSPEED
85+
distributed_backend = "deepspeed"
8686
DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"
8787

8888
def __init__(

0 commit comments

Comments
 (0)