Skip to content
This repository was archived by the owner on Sep 28, 2022. It is now read-only.

Commit d76cce0

Browse files
four4fishRaalsky
authored andcommitted
Add more trainer config tests (Lightning-AI#10319)
* Add more trainer config tests * Add more trainer config and ttp register tests * Add more trainer config and ttp register tests
1 parent 0d85d77 commit d76cce0

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

tests/plugins/test_plugins_registry.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytorch_lightning import Trainer
1717
from pytorch_lightning.plugins import (
1818
CheckpointIO,
19+
DDPFullyShardedPlugin,
1920
DDPPlugin,
2021
DDPShardedPlugin,
2122
DDPSpawnPlugin,
@@ -93,6 +94,18 @@ def test_tpu_spawn_debug_plugins_registry(tmpdir):
9394
assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
9495

9596

97+
def test_fsdp_plugins_registry(tmpdir):
98+
99+
plugin = "fsdp"
100+
101+
assert plugin in TrainingTypePluginsRegistry
102+
assert TrainingTypePluginsRegistry[plugin]["plugin"] == DDPFullyShardedPlugin
103+
104+
trainer = Trainer(strategy=plugin)
105+
106+
assert isinstance(trainer.training_type_plugin, DDPFullyShardedPlugin)
107+
108+
96109
@pytest.mark.parametrize(
97110
"plugin_name, plugin",
98111
[

tests/trainer/test_trainer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,6 +2158,15 @@ def training_step(self, batch, batch_idx):
21582158
dict(strategy="ddp_spawn", num_processes=1, gpus=None),
21592159
dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1),
21602160
),
2161+
(
2162+
dict(strategy="ddp_fully_sharded", gpus=1),
2163+
dict(
2164+
_distrib_type=DistributedType.DDP_FULLY_SHARDED,
2165+
_device_type=DeviceType.GPU,
2166+
num_gpus=1,
2167+
num_processes=1,
2168+
),
2169+
),
21612170
(
21622171
dict(strategy=DDPSpawnPlugin(), num_processes=2, gpus=None),
21632172
dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2),

0 commit comments

Comments
 (0)