Skip to content

Ensure accelerator is valid if running interactively #5970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Feb 23, 2021
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ca6dcb7
added ipython env check
ifsheldon Feb 14, 2021
5ee0322
Merge branch 'master' into ipython_env_check
tchaton Feb 15, 2021
df6348d
extracted a method and added a compatibility list
ifsheldon Feb 15, 2021
c2d3a12
Merge branch 'master' into ipython_env_check
ifsheldon Feb 15, 2021
5ad468e
fixed fstring
ifsheldon Feb 16, 2021
46dcd8d
inverted if, added early return
ifsheldon Feb 16, 2021
e37bbd6
Merge branch 'ipython_env_check' of https://github.com/ifsheldon/pyto…
ifsheldon Feb 16, 2021
6739c36
changed to only checking ipython env
ifsheldon Feb 17, 2021
165dffc
Merge branch 'master' into ipython_env_check
awaelchli Feb 18, 2021
da99886
move import to top
ifsheldon Feb 18, 2021
799059c
Merge branch 'ipython_env_check' of https://github.com/ifsheldon/pyto…
ifsheldon Feb 18, 2021
a6eaec2
fix formatting, docstring and line length
awaelchli Feb 18, 2021
86f2404
Merge branch 'master' into ipython_env_check
awaelchli Feb 18, 2021
c0ff8ce
fix isort
awaelchli Feb 18, 2021
70b7dbf
add test
awaelchli Feb 18, 2021
53dc217
moved compatible list to enum method
ifsheldon Feb 19, 2021
e2098c6
merged remote changes
ifsheldon Feb 19, 2021
0676368
fixed a minor print issue
ifsheldon Feb 19, 2021
386e3a9
changed to use utilities.imports
ifsheldon Feb 19, 2021
4384880
added a method to check ipython compatibility
ifsheldon Feb 19, 2021
73b18b1
fixed a minor issue when _distrib_type is None
ifsheldon Feb 19, 2021
d2dedab
Fix test
carmocca Feb 22, 2021
4619606
Merge branch 'master' into ipython_env_check
carmocca Feb 22, 2021
990c03e
IPython -> interactive
carmocca Feb 23, 2021
0cbb515
Update tests/accelerators/test_accelerator_connector.py
carmocca Feb 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
self._distrib_type = None

# finished configuring self._distrib_type, check compatibility with an interactive environment
self.check_interactive_compatibility()

# for DDP overwrite nb processes by requested GPUs
if (
self._device_type == DeviceType.GPU
Expand Down Expand Up @@ -558,6 +561,19 @@ def _set_horovod_backend(self):
else:
self.num_processes = hvd.local_size()

def check_interactive_compatibility(self):
    """
    Validate the selected distributed type against an interactive environment.

    Nothing happens when the environment is not interactive, when
    ``self._distrib_type`` is ``None``, or when the distributed type reports
    itself as interactive compatible.

    Raises:
        MisconfigurationException: if running interactively with a distributed
            type that is not interactive compatible.
    """
    # the flag is looked up at call time, not when this module is imported
    from pytorch_lightning.utilities import _IS_INTERACTIVE

    if not _IS_INTERACTIVE:
        return
    if self._distrib_type is None:
        return
    if self._distrib_type.is_interactive_compatible():
        return

    raise MisconfigurationException(
        f"Selected distributed backend {self._distrib_type} is not compatible with an interactive"
        " environment. Run your code as a script, or choose one of the compatible backends:"
        f" {', '.join(DistributedType.interactive_compatible_types())}"
    )

def check_horovod(self):
"""Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
if not _HOROVOD_AVAILABLE:
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""

import numpy

from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401
Expand All @@ -33,6 +32,7 @@
_HOROVOD_AVAILABLE,
_HYDRA_AVAILABLE,
_HYDRA_EXPERIMENTAL_AVAILABLE,
_IS_INTERACTIVE,
_module_available,
_NATIVE_AMP_AVAILABLE,
_OMEGACONF_AVAILABLE,
Expand Down
16 changes: 13 additions & 3 deletions pytorch_lightning/utilities/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
"""Enumerated utilities"""
from enum import Enum
from typing import Union
from typing import List, Optional, Union


class LightningEnum(str, Enum):
""" Type of any enumerator with allowed comparison to string invariant to cases. """

@classmethod
def from_str(cls, value: str) -> 'LightningEnum':
def from_str(cls, value: str) -> Optional['LightningEnum']:
statuses = [status for status in dir(cls) if not status.startswith('_')]
for st in statuses:
if st.lower() == value.lower():
Expand All @@ -31,7 +31,7 @@ def __eq__(self, other: Union[str, Enum]) -> bool:
other = other.value if isinstance(other, Enum) else str(other)
return self.value.lower() == other.lower()

def __hash__(self):
def __hash__(self) -> int:
# re-enable hashtable so it can be used as a dict key or in a set
# example: set(LightningEnum)
return hash(self.name)
Expand All @@ -58,6 +58,16 @@ class DistributedType(LightningEnum):
>>> DistributedType.DDP2 in ('ddp2', )
True
"""

@staticmethod
def interactive_compatible_types() -> List['DistributedType']:
    """Return the DistributedType members that are interactive compatible.

    A new list holding ``DP``, ``DDP_SPAWN`` and ``DDP_SHARDED_SPAWN`` (in that
    order) is built on every call.
    """
    compatible = [
        DistributedType.DP,
        DistributedType.DDP_SPAWN,
        DistributedType.DDP_SHARDED_SPAWN,
    ]
    return compatible

def is_interactive_compatible(self) -> bool:
    """Tell whether this member is listed by ``interactive_compatible_types()``."""
    compatible_types = DistributedType.interactive_compatible_types()
    return self in compatible_types

DP = 'dp'
DDP = 'ddp'
DDP2 = 'ddp2'
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""General utilities"""
import operator
import platform
import sys
from distutils.version import LooseVersion
from importlib.util import find_spec

Expand Down Expand Up @@ -49,10 +50,11 @@ def _compare_version(package: str, op, version) -> bool:


_IS_WINDOWS = platform.system() == "Windows"
_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
_TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])

_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
Expand All @@ -65,6 +67,7 @@ def _compare_version(package: str, op, version) -> bool:
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
_RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available('torchvision')
_XLA_AVAILABLE = _module_available("torch_xla")
15 changes: 15 additions & 0 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License

import os
import sys
from unittest import mock

import pytest
Expand All @@ -32,6 +33,7 @@
SingleDevicePlugin,
)
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel


Expand Down Expand Up @@ -387,6 +389,19 @@ def on_fit_start(self, trainer, pl_module):
trainer.fit(model)


@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", True)
@mock.patch('torch.cuda.device_count', return_value=2)
def test_ipython_incompatible_backend_error(*_):
    """Creating a Trainer with an incompatible backend must fail in an interactive environment.

    ``_IS_INTERACTIVE`` is a plain flag, not a callable, so it is patched with the
    literal ``True``. Patching it with ``return_value=True`` would substitute a
    ``MagicMock`` whose truthiness is only incidental and whose ``return_value``
    is never consulted.
    """
    with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
        Trainer(accelerator="ddp", gpus=2)

    with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
        Trainer(accelerator="ddp_cpu", num_processes=2)

    with pytest.raises(MisconfigurationException, match="backend ddp2 is not compatible"):
        Trainer(accelerator="ddp2", gpus=2)


@pytest.mark.parametrize(
["accelerator", "plugin"],
[('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
Expand Down