
Standalone Lite: Accelerators #14578

Merged: 37 commits, Sep 12, 2022
Changes from 30 commits

Commits (37)
fe59302
add accelerator implementations to lite
awaelchli Sep 7, 2022
7271f94
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
b6de11f
fix imports
awaelchli Sep 7, 2022
2ef04e6
rename registry argument
awaelchli Sep 7, 2022
9bbaf4f
fix test
awaelchli Sep 7, 2022
48bc1e8
fix tests
awaelchli Sep 7, 2022
0cf9651
Merge branch 'master' into lite/accelerators3
awaelchli Sep 7, 2022
dc09055
remove duplicated test
awaelchli Sep 7, 2022
6a14975
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
e6d619c
fix tests
awaelchli Sep 7, 2022
9055717
deprecation
awaelchli Sep 7, 2022
f016626
deprecations
awaelchli Sep 7, 2022
084bc6f
flake8
awaelchli Sep 7, 2022
9c19b48
fixes
awaelchli Sep 8, 2022
3d09dac
add mps to runif
awaelchli Sep 8, 2022
7a5a740
fix tests
awaelchli Sep 8, 2022
de78087
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2022
48ef646
Apply suggestions from code review
awaelchli Sep 9, 2022
6d60b96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
4e018c4
remove more
awaelchli Sep 9, 2022
983a6d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
2220350
local import
awaelchli Sep 9, 2022
cfce27e
Merge remote-tracking branch 'origin/lite/accelerators' into lite/acc…
awaelchli Sep 9, 2022
4ba5809
undo device stats :(
awaelchli Sep 9, 2022
231d8c3
fix import
awaelchli Sep 9, 2022
6e1f03a
stupid typehints
awaelchli Sep 9, 2022
1505eb4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
334e3cf
Merge branch 'master' into lite/accelerators
Borda Sep 9, 2022
e832e67
more refactors :(
awaelchli Sep 9, 2022
a90ef22
Merge remote-tracking branch 'origin/lite/accelerators' into lite/acc…
awaelchli Sep 9, 2022
8bf889b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2022
1195cec
fix
awaelchli Sep 11, 2022
f4dd9a5
Merge branch 'master' into lite/accelerators3
awaelchli Sep 12, 2022
c1f029e
rename init_device to setup_device
awaelchli Sep 12, 2022
4cc08fe
remove unused import
awaelchli Sep 12, 2022
9b8572d
make uppercase to differentiate from class
awaelchli Sep 12, 2022
06bf069
trick test after moving import locally
awaelchli Sep 12, 2022
22 changes: 22 additions & 0 deletions src/lightning_lite/accelerators/__init__.py
@@ -0,0 +1,22 @@
# Copyright The PyTorch Lightning team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_lite.accelerators.accelerator import Accelerator # noqa: F401
from lightning_lite.accelerators.cpu import CPUAccelerator # noqa: F401
from lightning_lite.accelerators.cuda import CUDAAccelerator # noqa: F401
from lightning_lite.accelerators.mps import MPSAccelerator # noqa: F401
from lightning_lite.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
from lightning_lite.accelerators.tpu import TPUAccelerator # noqa: F401

_ACCELERATORS_BASE_MODULE = "lightning_lite.accelerators"
AcceleratorRegistry = _AcceleratorRegistry()
call_register_accelerators(AcceleratorRegistry, _ACCELERATORS_BASE_MODULE)
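
For orientation, a minimal sketch of how this module-level registry can be queried after registration (illustrative only, not part of the diff; `available_accelerators()` is defined on `_AcceleratorRegistry`, and `get()` is assumed to return what was registered under a name, mirroring the Trainer-side registry):

```python
# Illustrative sketch only (not part of this PR): querying the registry
# that the __init__.py above builds at import time.
from lightning_lite.accelerators import AcceleratorRegistry

print(AcceleratorRegistry.available_accelerators())  # e.g. ['cpu', 'cuda', 'mps', 'tpu']

# Assumed behaviour, mirroring the Trainer-side registry: look up what was
# registered under a given name.
cpu_accelerator = AcceleratorRegistry.get("cpu")
```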
56 changes: 56 additions & 0 deletions src/lightning_lite/accelerators/accelerator.py
@@ -0,0 +1,56 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict

import torch


class Accelerator(ABC):
"""The Accelerator base class.

An Accelerator is meant to deal with one type of hardware.
"""

@abstractmethod
def init_device(self, device: torch.device) -> None:
"""Create and prepare the device for the current process."""

@abstractmethod
def teardown(self) -> None:
"""Clean up any state created by the accelerator."""

@staticmethod
@abstractmethod
def parse_devices(devices: Any) -> Any:
"""Accelerator device parsing logic."""

@staticmethod
@abstractmethod
def get_parallel_devices(devices: Any) -> Any:
"""Gets parallel devices for the Accelerator."""

@staticmethod
@abstractmethod
def auto_device_count() -> int:
"""Get the device count when set to auto."""

@staticmethod
@abstractmethod
def is_available() -> bool:
"""Detect if the hardware is available."""

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
pass
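
To make the contract concrete, here is a hypothetical minimal subclass; the name `MyAccelerator` and its trivial implementations are illustrative and not part of the PR:

```python
# Hypothetical subclass sketch showing how the abstract interface above is filled in.
from typing import Any, Dict, List

import torch

from lightning_lite.accelerators.accelerator import Accelerator


class MyAccelerator(Accelerator):
    def init_device(self, device: torch.device) -> None:
        pass  # nothing to prepare in this toy example

    def teardown(self) -> None:
        pass

    @staticmethod
    def parse_devices(devices: Any) -> int:
        return int(devices)

    @staticmethod
    def get_parallel_devices(devices: int) -> List[torch.device]:
        return [torch.device("cpu")] * devices

    @staticmethod
    def auto_device_count() -> int:
        return 1

    @staticmethod
    def is_available() -> bool:
        return True

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register("my_accelerator", cls, description="toy example")
```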
65 changes: 65 additions & 0 deletions src/lightning_lite/accelerators/cpu.py
@@ -0,0 +1,65 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Union

import torch

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities import device_parser


class CPUAccelerator(Accelerator):
"""Accelerator for CPU devices."""

def init_device(self, device: torch.device) -> None:
"""
Raises:
ValueError:
If the selected device is not CPU.
"""
if device.type != "cpu":
raise ValueError(f"Device should be CPU, got {device} instead.")

def teardown(self) -> None:
pass

@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> int:
"""Accelerator device parsing logic."""
devices = device_parser.parse_cpu_cores(devices)
return devices

@staticmethod
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = device_parser.parse_cpu_cores(devices)
return [torch.device("cpu")] * devices

@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return 1

@staticmethod
def is_available() -> bool:
"""CPU is always available for execution."""
return True

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
accelerator_registry.register(
"cpu",
cls,
description=cls.__class__.__name__,
)
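
A short usage sketch of the static helpers above (illustrative only, not part of the diff):

```python
# Illustrative sketch: CPUAccelerator helpers in isolation.
from lightning_lite.accelerators.cpu import CPUAccelerator

assert CPUAccelerator.is_available()              # CPU is always available
num_cores = CPUAccelerator.parse_devices(2)       # -> 2
devices = CPUAccelerator.get_parallel_devices(2)  # -> [device('cpu'), device('cpu')]
```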
64 changes: 64 additions & 0 deletions src/lightning_lite/accelerators/cuda.py
@@ -0,0 +1,64 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union

import torch

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities import device_parser


class CUDAAccelerator(Accelerator):
"""Accelerator for NVIDIA CUDA devices."""

def init_device(self, device: torch.device) -> None:
"""
Raises:
ValueError:
If the selected device is not of type CUDA.
"""
if device.type != "cuda":
raise ValueError(f"Device should be CUDA, got {device} instead.")
torch.cuda.set_device(device)

def teardown(self) -> None:
# clean up memory
torch.cuda.empty_cache()

@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
"""Accelerator device parsing logic."""
return device_parser.parse_gpu_ids(devices, include_cuda=True)

@staticmethod
def get_parallel_devices(devices: List[int]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
return [torch.device("cuda", i) for i in devices]

@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return device_parser.num_cuda_devices()

@staticmethod
def is_available() -> bool:
return device_parser.num_cuda_devices() > 0

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
accelerator_registry.register(
"cuda",
cls,
description=cls.__class__.__name__,
)
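
A corresponding sketch for the CUDA accelerator, guarded by the availability check (illustrative only; requires at least one NVIDIA GPU):

```python
# Illustrative sketch: CUDAAccelerator helpers, guarded by is_available().
from lightning_lite.accelerators.cuda import CUDAAccelerator

if CUDAAccelerator.is_available():
    device_ids = CUDAAccelerator.parse_devices(1)               # e.g. [0]
    devices = CUDAAccelerator.get_parallel_devices(device_ids)  # [device('cuda:0')]

    accelerator = CUDAAccelerator()
    accelerator.init_device(devices[0])  # calls torch.cuda.set_device under the hood
    accelerator.teardown()               # empties the CUDA cache
```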
75 changes: 75 additions & 0 deletions src/lightning_lite/accelerators/mps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from functools import lru_cache
from typing import Dict, List, Optional, Union

import torch

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities import device_parser
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12


class MPSAccelerator(Accelerator):
"""Accelerator for Metal Apple Silicon GPU devices."""

def init_device(self, device: torch.device) -> None:
"""
Raises:
ValueError:
If the selected device is not MPS.
"""
if device.type != "mps":
raise ValueError(f"Device should be MPS, got {device} instead.")

def teardown(self) -> None:
pass

@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
"""Accelerator device parsing logic."""
parsed_devices = device_parser.parse_gpu_ids(devices, include_mps=True)
return parsed_devices

@staticmethod
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
parsed_devices = MPSAccelerator.parse_devices(devices)
assert parsed_devices is not None

return [torch.device("mps", i) for i in range(len(parsed_devices))]

@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return 1

@staticmethod
@lru_cache
def is_available() -> bool:
"""MPS is only available for certain torch builds starting at torch>=1.12, and is only enabled on a machine
with the ARM-based Apple Silicon processors.
"""
return (
_TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")
)

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
accelerator_registry.register(
"mps",
cls,
description=cls.__class__.__name__,
)
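
An analogous sketch for MPS; devices are only returned on an Apple Silicon machine with a torch>=1.12 build (illustrative only, not part of the diff):

```python
# Illustrative sketch: MPSAccelerator helpers on Apple Silicon.
from lightning_lite.accelerators.mps import MPSAccelerator

if MPSAccelerator.is_available():
    devices = MPSAccelerator.get_parallel_devices(1)  # -> [device('mps:0')]
```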
src/lightning_lite/accelerators/registry.py
@@ -15,9 +15,9 @@
 from inspect import getmembers, isclass
 from typing import Any, Callable, Dict, List, Optional
 
+from lightning_lite.accelerators.accelerator import Accelerator
+from lightning_lite.utilities.exceptions import MisconfigurationException
 from lightning_lite.utilities.registry import _is_register_method_overridden
-from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class _AcceleratorRegistry(dict):
@@ -112,11 +112,8 @@ def __str__(self) -> str:
         return "Registered Accelerators: {}".format(", ".join(self.available_accelerators()))
 
 
-AcceleratorRegistry = _AcceleratorRegistry()
-
-
-def call_register_accelerators(base_module: str) -> None:
+def call_register_accelerators(registry: _AcceleratorRegistry, base_module: str) -> None:
     module = importlib.import_module(base_module)
     for _, mod in getmembers(module, isclass):
         if issubclass(mod, Accelerator) and _is_register_method_overridden(mod, Accelerator, "register_accelerators"):
-            mod.register_accelerators(AcceleratorRegistry)
+            mod.register_accelerators(registry)
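
The key behavioural change: the registry instance is now created by the caller (see `__init__.py` above) and passed into `call_register_accelerators`, instead of living as a module-level global inside `registry.py`. A sketch of the new call pattern (illustrative; it mirrors what `__init__.py` does):

```python
# Illustrative sketch of the new calling convention.
from lightning_lite.accelerators.registry import _AcceleratorRegistry, call_register_accelerators

registry = _AcceleratorRegistry()
call_register_accelerators(registry, "lightning_lite.accelerators")
print(registry.available_accelerators())  # names registered by the accelerator classes above
```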
59 changes: 59 additions & 0 deletions src/lightning_lite/accelerators/tpu.py
@@ -0,0 +1,59 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union

import torch

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities import device_parser
from lightning_lite.utilities.imports import _TPU_AVAILABLE


class TPUAccelerator(Accelerator):
"""Accelerator for TPU devices."""

def init_device(self, device: torch.device) -> None:
pass

def teardown(self) -> None:
pass

@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]:
"""Accelerator device parsing logic."""
return device_parser.parse_tpu_cores(devices)

@staticmethod
def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]:
"""Gets parallel devices for the Accelerator."""
if isinstance(devices, int):
return list(range(devices))
return devices

@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return 8

@staticmethod
def is_available() -> bool:
return _TPU_AVAILABLE

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
accelerator_registry.register(
"tpu",
cls,
description=cls.__class__.__name__,
)
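
A short sketch of the TPU helpers above (illustrative only; assumes a TPU environment so that parsing succeeds):

```python
# Illustrative sketch: TPUAccelerator device parsing.
from lightning_lite.accelerators.tpu import TPUAccelerator

cores = TPUAccelerator.parse_devices(8)               # -> 8
devices = TPUAccelerator.get_parallel_devices(cores)  # -> [0, 1, 2, 3, 4, 5, 6, 7]
```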
6 changes: 3 additions & 3 deletions src/lightning_lite/utilities/device_parser.py
@@ -215,9 +215,9 @@ def _get_all_available_mps_gpus() -> List[int]:
         A list of all available MPS GPUs
     """
     # lazy import to avoid circular dependencies
-    # from lightning_lite.accelerators.mps import _MPS_AVAILABLE
-    _MPS_AVAILABLE = False  # TODO(lite): revert this once MPS utils have moved
-    return [0] if _MPS_AVAILABLE else []
+    from lightning_lite.accelerators.mps import MPSAccelerator
+
+    return [0] if MPSAccelerator.is_available() else []
 
 
 def _get_all_available_cuda_gpus() -> List[int]:
1 change: 0 additions & 1 deletion src/lightning_lite/utilities/imports.py
@@ -35,7 +35,6 @@
 _HOROVOD_AVAILABLE = module_available("horovod.torch")
 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _POPTORCH_AVAILABLE = package_available("poptorch")
-_PSUTIL_AVAILABLE = package_available("psutil")
 _XLA_AVAILABLE: bool = package_available("torch_xla")
 
 # TODO(lite): import this from the fairscale files once they move to lite package