Skip to content

Commit 21f7e79

Browse files
committed
Add more trainer config and ttp register tests
1 parent 258d6b2 commit 21f7e79

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

tests/plugins/test_plugins_registry.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytorch_lightning import Trainer
1717
from pytorch_lightning.plugins import (
1818
CheckpointIO,
19+
DDPFullyShardedPlugin,
1920
DDPPlugin,
2021
DDPShardedPlugin,
2122
DDPSpawnPlugin,
@@ -93,6 +94,18 @@ def test_tpu_spawn_debug_plugins_registry(tmpdir):
9394
assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
9495

9596

97+
def test_fsdp_plugins_registry(tmpdir):
98+
99+
plugin = "fsdp"
100+
101+
assert plugin in TrainingTypePluginsRegistry
102+
assert TrainingTypePluginsRegistry[plugin]["plugin"] == DDPFullyShardedPlugin
103+
104+
trainer = Trainer(strategy=plugin)
105+
106+
assert isinstance(trainer.training_type_plugin, DDPFullyShardedPlugin)
107+
108+
96109
@pytest.mark.parametrize(
97110
"plugin_name, plugin",
98111
[

tests/trainer/test_trainer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,13 +2162,14 @@ def training_step(self, batch, batch_idx):
21622162
dict(strategy="ddp_spawn", num_processes=1, gpus=None),
21632163
dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1),
21642164
),
2165-
(
2166-
dict(strategy="fsdp", gpus=2),
2167-
dict(_distrib_type=DistributedType.DDP_FULLY_SHARDED, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1),
2168-
),
21692165
(
21702166
dict(strategy="ddp_fully_sharded", gpus=1),
2171-
dict(_distrib_type=DistributedType.DDP_FULLY_SHARDED, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1),
2167+
dict(
2168+
_distrib_type=DistributedType.DDP_FULLY_SHARDED,
2169+
_device_type=DeviceType.GPU,
2170+
num_gpus=1,
2171+
num_processes=1,
2172+
),
21722173
),
21732174
(
21742175
dict(strategy=DDPSpawnPlugin(), num_processes=2, gpus=None),

0 commit comments

Comments
 (0)