From 80e15450926bd39fc6b762a7d5e37603658504ae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 19 Feb 2022 15:38:45 +0100
Subject: [PATCH 1/6] fix is_interactive_compatible

---
 pytorch_lightning/strategies/launchers/base.py      |  5 +++++
 pytorch_lightning/strategies/launchers/spawn.py     |  4 ++++
 .../strategies/launchers/subprocess_script.py       |  4 ++++
 pytorch_lightning/strategies/launchers/xla_spawn.py |  4 ++++
 .../trainer/connectors/accelerator_connector.py     |  9 +++------
 tests/accelerators/test_accelerator_connector.py    | 10 +++++++++-
 6 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/strategies/launchers/base.py b/pytorch_lightning/strategies/launchers/base.py
index 293c0a2ce4508..2acf54afef245 100644
--- a/pytorch_lightning/strategies/launchers/base.py
+++ b/pytorch_lightning/strategies/launchers/base.py
@@ -26,6 +26,11 @@ class _Launcher(ABC):
     cluster environment, hardware, strategy, etc.
     """

+    @property
+    @abstractmethod
+    def is_interactive_compatible(self) -> bool:
+        """Returns whether this launcher can work in interactive environments such as Jupyter notebooks."""
+
     @abstractmethod
     def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
         """Launches the processes."""
diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py
index d1349fd39cd97..19ff0527f03d7 100644
--- a/pytorch_lightning/strategies/launchers/spawn.py
+++ b/pytorch_lightning/strategies/launchers/spawn.py
@@ -48,6 +48,10 @@ class _SpawnLauncher(_Launcher):
     def __init__(self, strategy: Strategy) -> None:
         self._strategy = strategy

+    @property
+    def is_interactive_compatible(self) -> bool:
+        return False  # TODO: the return value should depend on 1) start_method 2) CUDA vs. CPU
+
     def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.

diff --git a/pytorch_lightning/strategies/launchers/subprocess_script.py b/pytorch_lightning/strategies/launchers/subprocess_script.py
index e4b41500412d3..e482b6897a2a4 100644
--- a/pytorch_lightning/strategies/launchers/subprocess_script.py
+++ b/pytorch_lightning/strategies/launchers/subprocess_script.py
@@ -68,6 +68,10 @@ class _SubprocessScriptLauncher(_Launcher):
         num_nodes: The total number of nodes that participate in this process group.
     """

+    @property
+    def is_interactive_compatible(self) -> bool:
+        return False
+
     def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None:
         super().__init__()
         self.cluster_environment = cluster_environment
diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py
index 8bac7888c568b..520c6dff17ae2 100644
--- a/pytorch_lightning/strategies/launchers/xla_spawn.py
+++ b/pytorch_lightning/strategies/launchers/xla_spawn.py
@@ -44,6 +44,10 @@ class _XLASpawnLauncher(_SpawnLauncher):
     - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
     """

+    @property
+    def is_interactive_compatible(self) -> bool:
+        return True
+
     def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 20c5f485b4e71..4cda4acb090a7 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -715,19 +715,16 @@ def _lazy_init_strategy(self) -> None:

         from pytorch_lightning.utilities import _IS_INTERACTIVE

-        # TODO move is_compatible logic to strategy API
-        interactive_compatible_strategy = (
+        interactive_recommended_strategy = (
             DataParallelStrategy.strategy_name,
-            DDPSpawnStrategy.strategy_name,
-            DDPSpawnShardedStrategy.strategy_name,
             TPUSpawnStrategy.strategy_name,
         )
-        if _IS_INTERACTIVE and self.strategy.strategy_name not in interactive_compatible_strategy:
+        if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
             raise MisconfigurationException(
                 f"`Trainer(strategy={self.strategy.strategy_name!r})` or"
                 f" `Trainer(accelerator={self.strategy.strategy_name!r})` is not compatible with an interactive"
                 " environment. Run your code as a script, or choose one of the compatible backends:"
-                f" {', '.join(interactive_compatible_strategy)}."
+                f" Trainer(strategy=None|{'|'.join(interactive_recommended_strategy)})."
                 " In case you are spawning processes yourself, make sure to include the Trainer"
                 " creation inside the worker function."
             )
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 76fa6d64f5a56..9c05560b33bd1 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -402,10 +402,18 @@ def test_ipython_incompatible_backend_error(*_):
     with pytest.raises(MisconfigurationException, match=r"strategy='ddp2'\)`.*is not compatible"):
         Trainer(strategy="ddp2", gpus=2)

+    with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"):
+        # Edge case: AcceleratorConnector maps dp to ddp if no devices were selected
+        Trainer(strategy="dp")
+

 @mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
 def test_ipython_compatible_backend(*_):
-    Trainer(strategy="ddp_spawn", num_processes=2)
+    Trainer()
+    Trainer(strategy="dp", devices=2)
+    Trainer(strategy="ddp_spawn", devices=2)
+    Trainer(strategy="tpu_spawn", devices=2)
+    Trainer(strategy="coconut is not a nut", devices=2)


 @pytest.mark.parametrize(["accelerator", "plugin"], [("ddp_spawn", "ddp_sharded"), (None, "ddp_sharded")])

From eec3246c8b7635fa6ca11906d469b11c8b5748a2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 19 Feb 2022 16:13:01 +0100
Subject: [PATCH 2/6] improve tests

---
 .../test_accelerator_connector.py | 20 ++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 9c05560b33bd1..ab13d72213add 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -20,6 +20,7 @@
 import torch
 import torch.distributed

+import pytorch_lightning
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
@@ -393,27 +394,28 @@ def test_dist_backend_accelerator_mapping(*_):
     assert trainer.strategy.local_rank == 0


-@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
@mock.patch("torch.cuda.device_count", return_value=2) -def test_ipython_incompatible_backend_error(*_): +def test_ipython_incompatible_backend_error(_, monkeypatch): + monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"): Trainer(strategy="ddp", gpus=2) with pytest.raises(MisconfigurationException, match=r"strategy='ddp2'\)`.*is not compatible"): Trainer(strategy="ddp2", gpus=2) + with pytest.raises(MisconfigurationException, match=r"strategy='ddp_spawn'\)`.*is not compatible"): + Trainer(strategy="ddp_spawn") + with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"): - # Edge case: AcceleratorConnector maps dp to ddp if no devices were selected + # Edge case: AcceleratorConnector maps dp to ddp if accelerator != gpu Trainer(strategy="dp") -@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True) -def test_ipython_compatible_backend(*_): +def test_ipython_compatible_backend(monkeypatch): + monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) Trainer() - Trainer(strategy="dp", devices=2) - Trainer(strategy="ddp_spawn", devices=2) - Trainer(strategy="tpu_spawn", devices=2) - Trainer(strategy="coconut is not a nut", devices=2) + Trainer(strategy="dp", accelerator="gpu") + Trainer(accelerator="tpu") @pytest.mark.parametrize(["accelerator", "plugin"], [("ddp_spawn", "ddp_sharded"), (None, "ddp_sharded")]) From 0cb5bac408cbfa5ad5efc3a586733e41f3addc12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 19 Feb 2022 16:20:20 +0100 Subject: [PATCH 3/6] update message --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 4cda4acb090a7..b283f46531afb 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -723,7 +723,7 @@ def _lazy_init_strategy(self) -> None: raise MisconfigurationException( f"`Trainer(strategy={self.strategy.strategy_name!r})` or" f" `Trainer(accelerator={self.strategy.strategy_name!r})` is not compatible with an interactive" - " environment. Run your code as a script, or choose one of the compatible backends:" + " environment. Run your code as a script, or choose one of the compatible strategies:" f" Trainer(strategy=None|{'|'.join(interactive_recomended_strategy)})." " In case you are spawning processes yourself, make sure to include the Trainer" " creation inside the worker function." 
From b7355b242fc464835b17afff532b5069ed2c84d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 21 Feb 2022 23:38:17 +0100
Subject: [PATCH 4/6] address review

---
 .../trainer/connectors/accelerator_connector.py  |  8 ++------
 pytorch_lightning/utilities/enums.py             |  2 --
 tests/accelerators/test_accelerator_connector.py | 11 +++++++----
 3 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index fee6ea4066d66..4cfb7959940b7 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -70,7 +70,7 @@
     LightningEnum,
     rank_zero_deprecation,
     rank_zero_info,
-    rank_zero_warn,
+    rank_zero_warn, _StrategyType,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import (
@@ -734,16 +734,12 @@ def _lazy_init_strategy(self) -> None:

         from pytorch_lightning.utilities import _IS_INTERACTIVE

-        interactive_recommended_strategy = (
-            DataParallelStrategy.strategy_name,
-            TPUSpawnStrategy.strategy_name,
-        )
         if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
             raise MisconfigurationException(
                 f"`Trainer(strategy={self.strategy.strategy_name!r})` or"
                 f" `Trainer(accelerator={self.strategy.strategy_name!r})` is not compatible with an interactive"
                 " environment. Run your code as a script, or choose one of the compatible strategies:"
-                f" Trainer(strategy=None|{'|'.join(interactive_recommended_strategy)})."
+                f" Trainer(strategy=None|{'|'.join(_StrategyType.interactive_compatible_types())})."
                 " In case you are spawning processes yourself, make sure to include the Trainer"
                 " creation inside the worker function."
             )
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 103fc87ecde1b..105b167a29910 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -254,8 +254,6 @@ def interactive_compatible_types() -> list[_StrategyType]:
         """Returns a list containing interactive compatible _StrategyTypes."""
         return [
             _StrategyType.DP,
-            _StrategyType.DDP_SPAWN,
-            _StrategyType.DDP_SHARDED_SPAWN,
             _StrategyType.TPU_SPAWN,
         ]

diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 86f5286499e52..3eee5a085a354 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -406,16 +406,19 @@ def test_ipython_incompatible_backend_error(_, monkeypatch):
     with pytest.raises(MisconfigurationException, match=r"strategy='ddp_spawn'\)`.*is not compatible"):
         Trainer(strategy="ddp_spawn")

+    with pytest.raises(MisconfigurationException, match=r"strategy='ddp_sharded_spawn'\)`.*is not compatible"):
+        Trainer(strategy="ddp_sharded_spawn")
+
     with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"):
         # Edge case: AcceleratorConnector maps dp to ddp if accelerator != gpu
         Trainer(strategy="dp")


-def test_ipython_compatible_backend(monkeypatch):
+@pytest.mark.parametrize("trainer_kwargs", [{}, dict(strategy="dp", accelerator="gpu"), dict(accelerator="tpu")])
+def test_ipython_compatible_backend(trainer_kwargs, monkeypatch):
     monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
-    Trainer()
-    Trainer(strategy="dp", accelerator="gpu")
-    Trainer(accelerator="tpu")
+    trainer = Trainer(**trainer_kwargs)
+    assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible


 @pytest.mark.parametrize(["accelerator", "plugin"], [("ddp_spawn", "ddp_sharded"), (None, "ddp_sharded")])

From 961dd4b1bef9cd61258351357996ec18f6488c3c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 21 Feb 2022 22:39:43 +0000
Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 4cfb7959940b7..f2d27a249f6f2 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -65,12 +65,13 @@
     TPUSpawnStrategy,
 )
 from pytorch_lightning.utilities import (
+    _StrategyType,
     AMPType,
     device_parser,
     LightningEnum,
     rank_zero_deprecation,
     rank_zero_info,
-    rank_zero_warn, _StrategyType,
+    rank_zero_warn,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import (

From 7d6d7c315ea80db536c164c9d1f611f879e26e4f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 22 Feb 2022 13:19:24 +0100
Subject: [PATCH 6/6] add comments for spawn interactive

---
 pytorch_lightning/strategies/launchers/spawn.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py
index 128f1e0172740..d67f9e620a45d 100644
--- a/pytorch_lightning/strategies/launchers/spawn.py
+++ b/pytorch_lightning/strategies/launchers/spawn.py
@@ -51,7 +51,10 @@ def __init__(self, strategy: Strategy) -> None:

     @property
     def is_interactive_compatible(self) -> bool:
-        return False  # TODO: the return value should depend on 1) start_method 2) CUDA vs. CPU
+        # The start method 'spawn' is currently the only one that works with DDP and CUDA support
+        # The start method 'fork' is the only one supported in Jupyter environments but not compatible with CUDA
+        # For more context, see https://github.com/PyTorchLightning/pytorch-lightning/issues/7550
+        return self._start_method == "fork" and self._strategy.root_device.type != "cuda"

     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
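
The rule patch 6 arrives at is worth stating outside the diff: a spawn-based launcher is notebook-safe only when it forks (the "spawn" and "forkserver" start methods re-import the main module, which a notebook cannot provide) and only when no CUDA device is in play (CUDA cannot be safely re-initialized in a forked child). Below is a small self-contained sketch of that decision; `spawn_is_interactive_compatible` is a made-up name for illustration, while the real logic lives on `_SpawnLauncher`:

    # Hedged sketch of the compatibility rule documented in patch 6.
    import torch


    def spawn_is_interactive_compatible(start_method: str, root_device: torch.device) -> bool:
        # "spawn" and "forkserver" re-import the main module, which fails in
        # notebooks; "fork" reuses the interpreter state, but CUDA cannot be
        # re-initialized in a forked subprocess.
        # See https://github.com/PyTorchLightning/pytorch-lightning/issues/7550
        return start_method == "fork" and root_device.type != "cuda"


    # Constructing torch.device objects does not require an actual GPU:
    assert spawn_is_interactive_compatible("fork", torch.device("cpu"))
    assert not spawn_is_interactive_compatible("spawn", torch.device("cpu"))
    assert not spawn_is_interactive_compatible("fork", torch.device("cuda", 0))

The same shape explains the test change in patch 4: rather than asserting on concrete strategy names, the parametrized test constructs a Trainer and asserts `trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible`, which is exactly the condition the connector's error check keys on.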