Skip to content

Simplify some ddp-spawn tests #10921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 12 additions & 33 deletions tests/plugins/test_sharded_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
Expand All @@ -32,43 +31,23 @@ def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_v


@RunIf(fairscale=True)
@pytest.mark.parametrize(
    # Each strategy alias must resolve to its corresponding sharded training-type plugin.
    "strategy,expected", [("ddp_sharded", DDPShardedPlugin), ("ddp_sharded_spawn", DDPSpawnShardedPlugin)]
)
def test_sharded_ddp_choice(tmpdir, strategy, expected):
    """Test to ensure that plugin is correctly chosen."""
    # Constructing the Trainer is enough to trigger strategy resolution;
    # no fit() call (and no Callback-based probing) is required.
    trainer = Trainer(fast_dev_run=True, strategy=strategy)
    assert isinstance(trainer.accelerator.training_type_plugin, expected)


@RunIf(min_gpus=1, fairscale=True)
@pytest.mark.parametrize(
    # Each strategy alias must resolve to its corresponding sharded training-type plugin.
    "strategy,expected", [("ddp_sharded", DDPShardedPlugin), ("ddp_sharded_spawn", DDPSpawnShardedPlugin)]
)
def test_ddp_choice_sharded_amp(tmpdir, strategy, expected):
    """Test to ensure that plugin native amp plugin is correctly chosen when using sharded."""
    # precision=16 with a GPU exercises the native-AMP path; plugin resolution
    # happens at Trainer construction, so no fit() call is needed.
    trainer = Trainer(fast_dev_run=True, gpus=1, precision=16, strategy=strategy)
    assert isinstance(trainer.accelerator.training_type_plugin, expected)


@RunIf(skip_windows=True, fairscale=True)
Expand Down
12 changes: 4 additions & 8 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,14 +1515,10 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, *arg
def test_spawn_predict_return_predictions(_, __, accelerator):
    """Test that `return_predictions=True` raises a MisconfigurationException with spawn training type plugins."""
    model = BoringModel()
    # Resolution to DDPSpawnPlugin happens at Trainer construction; the inner
    # `run()` helper was unnecessary indirection for a single invocation.
    trainer = Trainer(accelerator=accelerator, strategy="ddp_spawn", devices=2, fast_dev_run=True)
    assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
    with pytest.raises(MisconfigurationException, match="`return_predictions` should be set to `False`"):
        trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True)


@pytest.mark.parametrize("return_predictions", [None, False, True])
Expand Down