From 3d5ec22f216c15df3f7a61893c0dd35f17c6eec1 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 27 Oct 2021 16:03:14 +0200
Subject: [PATCH 1/5] Raise exception for `strategy=ddp_cpu|tpu_spawn`

---
 .../trainer/connectors/accelerator_connector.py  | 10 ++++++++++
 tests/accelerators/test_accelerator_connector.py |  7 +++++++
 2 files changed, 17 insertions(+)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 8271cf9bdc742..1d77b1b25537d 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -288,6 +288,16 @@ def _handle_accelerator_and_strategy(self) -> None:
                 f" also passed `Trainer(accelerator={self.distributed_backend!r})`."
                 f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead."
             )
+        if self.strategy is not None and self.strategy == DistributedType.TPU_SPAWN:
+            raise MisconfigurationException(
+                "`Trainer(strategy='tpu_spawn')` is not a valid strategy,"
+                "you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead"
+            )
+        if self.strategy is not None and self.strategy == DistributedType.DDP_CPU:
+            raise MisconfigurationException(
+                "`Trainer(strategy='ddp_cpu')` is not a valid strategy,"
+                "you can use `Trainer(strategy='ddp'|'ddp_spawn', accelerator='cpu')` instead"
+            )
 
     def _set_training_type_plugin(self) -> None:
         if isinstance(self.strategy, str) and self.strategy in TrainingTypePluginsRegistry:
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 7ad93b167794d..e7505dee56cb1 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -645,6 +645,13 @@ def test_exception_when_strategy_used_with_plugins():
         Trainer(plugins="ddp_find_unused_parameters_false", strategy="ddp_spawn")
 
 
+def test_exception_invalid_strategy():
+    with pytest.raises(MisconfigurationException, match=r"strategy='ddp_cpu'\)` is not a valid"):
+        Trainer(strategy="ddp_cpu")
+    with pytest.raises(MisconfigurationException, match=r"strategy='tpu_spawn'\)` is not a valid"):
+        Trainer(strategy="tpu_spawn")
+
+
 @pytest.mark.parametrize(
     ["strategy", "plugin"],
     [

From a381998bd07277ad216a220b802e31d8f633255f Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 27 Oct 2021 16:04:02 +0200
Subject: [PATCH 2/5] whitespace

---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 1d77b1b25537d..338d2592e5b8a 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -291,12 +291,12 @@ def _handle_accelerator_and_strategy(self) -> None:
         if self.strategy is not None and self.strategy == DistributedType.TPU_SPAWN:
             raise MisconfigurationException(
                 "`Trainer(strategy='tpu_spawn')` is not a valid strategy,"
-                "you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead"
+                " you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
             )
         if self.strategy is not None and self.strategy == DistributedType.DDP_CPU:
             raise MisconfigurationException(
                 "`Trainer(strategy='ddp_cpu')` is not a valid strategy,"
-                "you can use `Trainer(strategy='ddp'|'ddp_spawn', accelerator='cpu')` instead"
+                " you can use `Trainer(strategy='ddp'|'ddp_spawn', accelerator='cpu')` instead."
             )
 
     def _set_training_type_plugin(self) -> None:

From beb5e3761fb3af2376d81d772c657c5bdb75a7c2 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 27 Oct 2021 16:49:15 +0200
Subject: [PATCH 3/5] Fix tests

---
 tests/callbacks/test_early_stopping.py | 19 +++++++--
 tests/trainer/test_trainer.py          | 54 +++++++++-----------------
 2 files changed, 34 insertions(+), 39 deletions(-)

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 20f435224fa76..2b4fe9f05eb87 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -389,15 +389,25 @@ def on_train_end(self) -> None:
     [
         ([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, None, 1),
         ([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, None, 1),
-        pytest.param([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_cpu", 2, **_NO_WIN),
-        pytest.param([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_cpu", 2, **_NO_WIN),
+        pytest.param([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_spawn", 2, **_NO_WIN),
+        pytest.param([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, **_NO_WIN),
         ([EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, None, 1),
         ([EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, None, 1),
         pytest.param(
-            [EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, "ddp_cpu", 2, **_NO_WIN
+            [EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)],
+            3,
+            True,
+            "ddp_spawn",
+            2,
+            **_NO_WIN,
         ),
         pytest.param(
-            [EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, "ddp_cpu", 2, **_NO_WIN
+            [EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)],
+            3,
+            True,
+            "ddp_spawn",
+            2,
+            **_NO_WIN,
         ),
     ],
 )
@@ -419,6 +429,7 @@ def test_multiple_early_stopping_callbacks(
         overfit_batches=0.20,
         max_epochs=20,
         strategy=strategy,
+        accelerator="cpu",
         num_processes=num_processes,
     )
     trainer.fit(model)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index a45bf105722cf..bfa96ffaa3cd8 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1368,9 +1368,9 @@ def on_predict_epoch_end(self, trainer, pl_module, outputs):
 
 def predict(
     tmpdir,
-    strategy,
-    gpus,
-    num_processes,
+    strategy=None,
+    accelerator=None,
+    devices=None,
     model=None,
     plugins=None,
     datamodule=True,
@@ -1391,8 +1391,8 @@ def predict(
         log_every_n_steps=1,
         enable_model_summary=False,
         strategy=strategy,
-        gpus=gpus,
-        num_processes=num_processes,
+        accelerator=accelerator,
+        devices=devices,
         plugins=plugins,
         enable_progress_bar=enable_progress_bar,
         callbacks=[cb, cb_1] if use_callbacks else [],
@@ -1431,7 +1431,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
             return super().predict_step(batch, batch_idx, dataloader_idx)
 
     with pytest.warns(UserWarning, match="predict returned None"):
-        predict(tmpdir, None, None, 1, model=CustomBoringModel(), use_callbacks=False)
+        predict(tmpdir, model=CustomBoringModel(), use_callbacks=False)
 
 
 def test_trainer_predict_grad(tmpdir):
@@ -1440,7 +1440,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
         assert batch.expand_as(batch).grad_fn is None
         return super().predict_step(batch, batch_idx, dataloader_idx)
 
-    predict(tmpdir, None, None, 1, model=CustomBoringModel(), use_callbacks=False)
+    predict(tmpdir, model=CustomBoringModel(), use_callbacks=False)
 
     x = torch.zeros(1, requires_grad=True)
     assert x.expand_as(x).grad_fn is not None
@@ -1449,33 +1449,33 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
 @pytest.mark.parametrize("enable_progress_bar", [False, True])
 @pytest.mark.parametrize("datamodule", [False, True])
 def test_trainer_predict_cpu(tmpdir, datamodule, enable_progress_bar):
-    predict(tmpdir, None, None, 1, datamodule=datamodule, enable_progress_bar=enable_progress_bar)
+    predict(tmpdir, datamodule=datamodule, enable_progress_bar=enable_progress_bar)
 
 
 @RunIf(min_gpus=2, special=True)
 @pytest.mark.parametrize("num_gpus", [1, 2])
 def test_trainer_predict_dp(tmpdir, num_gpus):
-    predict(tmpdir, "dp", num_gpus, None)
+    predict(tmpdir, strategy="dp", accelerator="gpu", devices=num_gpus)
 
 
 @RunIf(min_gpus=2, special=True, fairscale=True)
 def test_trainer_predict_ddp(tmpdir):
-    predict(tmpdir, "ddp", 2, None)
+    predict(tmpdir, strategy="ddp", accelerator="gpu", devices=2)
 
 
 @RunIf(min_gpus=2, skip_windows=True, special=True)
 def test_trainer_predict_ddp_spawn(tmpdir):
-    predict(tmpdir, "ddp_spawn", 2, None)
+    predict(tmpdir, strategy="ddp_spawn", accelerator="gpu", devices=2)
 
 
-@RunIf(min_gpus=2, special=True)
+@RunIf(min_gpus=1, special=True)
 def test_trainer_predict_1_gpu(tmpdir):
-    predict(tmpdir, None, 1, None)
+    predict(tmpdir, accelerator="gpu", devices=1)
 
 
 @RunIf(skip_windows=True)
 def test_trainer_predict_ddp_cpu(tmpdir):
-    predict(tmpdir, "ddp_cpu", 0, 2)
+    predict(tmpdir, strategy="ddp_spawn", accelerator="cpu", devices=2)
 
 
 @pytest.mark.parametrize("dataset_cls", [RandomDataset, RandomIterableDatasetWithLen, RandomIterableDataset])
@@ -1501,7 +1501,8 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, *arg
 
 @patch("torch.cuda.device_count", return_value=2)
 @patch("torch.cuda.is_available", return_value=True)
-def test_spawn_predict_return_predictions(*_):
+@pytest.mark.parametrize("accelerator", ("cpu", "gpu"))
+def test_spawn_predict_return_predictions(_, __, accelerator):
     """Test that `return_predictions=True` raise a MisconfigurationException with spawn training type plugins."""
     model = BoringModel()
 
@@ -1511,8 +1512,7 @@ def run(expected_plugin, **trainer_kwargs):
         with pytest.raises(MisconfigurationException, match="`return_predictions` should be set to `False`"):
             trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True)
 
-    run(DDPSpawnPlugin, strategy="ddp_spawn", gpus=2)
-    run(DDPSpawnPlugin, strategy="ddp_cpu", num_processes=2)
+    run(DDPSpawnPlugin, accelerator=accelerator, strategy="ddp_spawn", devices=2)
 
 
 @pytest.mark.parametrize("return_predictions", [None, False, True])
@@ -1809,7 +1809,7 @@ def on_predict_start(self) -> None:
 
 
 @pytest.mark.parametrize(
-    "strategy,num_processes", [(None, 1), pytest.param("ddp_cpu", 2, marks=RunIf(skip_windows=True))]
+    "strategy,num_processes", [(None, 1), pytest.param("ddp_spawn", 2, marks=RunIf(skip_windows=True))]
 )
 def test_model_in_correct_mode_during_stages(tmpdir, strategy, num_processes):
     model = TrainerStagesModel()
@@ -1837,7 +1837,7 @@ def test_fit_test_synchronization(tmpdir):
     model = TestDummyModelForCheckpoint()
     checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="x", mode="min", save_top_k=1)
     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=2, strategy="ddp_cpu", num_processes=2, callbacks=[checkpoint]
+        default_root_dir=tmpdir, max_epochs=2, strategy="ddp_spawn", num_processes=2, callbacks=[checkpoint]
     )
     trainer.fit(model)
     assert os.path.exists(checkpoint.best_model_path), f"Could not find checkpoint at rank {trainer.global_rank}"
@@ -2158,22 +2158,6 @@ def training_step(self, batch, batch_idx):
             dict(strategy="ddp_spawn", num_processes=1, gpus=None),
             dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1),
         ),
-        (
-            dict(strategy="ddp_cpu", num_processes=1, num_nodes=1, gpus=None),
-            dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1),
-        ),
-        (
-            dict(strategy="ddp_cpu", num_processes=2, num_nodes=1, gpus=None),
-            dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2),
-        ),
-        (
-            dict(strategy="ddp_cpu", num_processes=1, num_nodes=2, gpus=None),
-            dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1),
-        ),
-        (
-            dict(strategy="ddp_cpu", num_processes=2, num_nodes=2, gpus=None),
-            dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2),
-        ),
     ],
 )
 def test_trainer_config_strategy(trainer_kwargs, expected, monkeypatch):

From fba4e8871a19958d3cc350eaf1ee1e2a57c09953 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 28 Oct 2021 14:51:48 +0200
Subject: [PATCH 4/5] Remove None check

---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 338d2592e5b8a..e708e637659f6 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -288,12 +288,12 @@ def _handle_accelerator_and_strategy(self) -> None:
             f" also passed `Trainer(accelerator={self.distributed_backend!r})`."
             f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead."
         )
-        if self.strategy is not None and self.strategy == DistributedType.TPU_SPAWN:
+        if self.strategy == DistributedType.TPU_SPAWN:
             raise MisconfigurationException(
                 "`Trainer(strategy='tpu_spawn')` is not a valid strategy,"
                 " you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
             )
-        if self.strategy is not None and self.strategy == DistributedType.DDP_CPU:
+        if self.strategy == DistributedType.DDP_CPU:
             raise MisconfigurationException(
                 "`Trainer(strategy='ddp_cpu')` is not a valid strategy,"
                 " you can use `Trainer(strategy='ddp'|'ddp_spawn', accelerator='cpu')` instead."
From 707f302f878f3a528c6c8bb04bfc687c7ab5bd1c Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Fri, 29 Oct 2021 20:37:26 +0530
Subject: [PATCH 5/5] Fix tpu test

---
 tests/accelerators/test_tpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py
index 4b2f56c329dee..c62aca058dbaf 100644
--- a/tests/accelerators/test_tpu.py
+++ b/tests/accelerators/test_tpu.py
@@ -229,7 +229,7 @@ def test_ddp_cpu_not_supported_on_tpus():
 
 
 @RunIf(tpu=True)
-@pytest.mark.parametrize("strategy", ["tpu_spawn", "tpu_spawn_debug"])
+@pytest.mark.parametrize("strategy", ["ddp_spawn", "tpu_spawn_debug"])
 def test_strategy_choice_tpu_str(tmpdir, strategy):
     trainer = Trainer(strategy=strategy, accelerator="tpu", devices=8)
    assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
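
Note: as a quick sanity check of what this series changes for users, here is a minimal sketch assuming a PyTorch Lightning checkout with all five patches applied. The `match` patterns and the suggested `ddp_spawn` + `accelerator` replacement come from the new error messages and tests above; `devices=2` simply mirrors the updated predict tests.

    import pytest

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # Both shorthand strategies are rejected now instead of being translated silently.
    with pytest.raises(MisconfigurationException, match=r"strategy='ddp_cpu'\)` is not a valid"):
        Trainer(strategy="ddp_cpu")

    with pytest.raises(MisconfigurationException, match=r"strategy='tpu_spawn'\)` is not a valid"):
        Trainer(strategy="tpu_spawn")

    # The replacement the error messages suggest: name the accelerator explicitly.
    trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2)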