Standalone Lite: Accelerators #14578
Merged

37 commits
- fe59302 add accelerator implementations to lite (awaelchli)
- 7271f94 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- b6de11f fix imports (awaelchli)
- 2ef04e6 rename registry argument (awaelchli)
- 9bbaf4f fix test (awaelchli)
- 48bc1e8 fix tests (awaelchli)
- 0cf9651 Merge branch 'master' into lite/accelerators3 (awaelchli)
- dc09055 remove duplicated test (awaelchli)
- 6a14975 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- e6d619c fix tests (awaelchli)
- 9055717 deprecation (awaelchli)
- f016626 deprecations (awaelchli)
- 084bc6f flake8 (awaelchli)
- 9c19b48 fixes (awaelchli)
- 3d09dac add mps to runif (awaelchli)
- 7a5a740 fix tests (awaelchli)
- de78087 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 48ef646 Apply suggestions from code review (awaelchli)
- 6d60b96 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 4e018c4 remove more (awaelchli)
- 983a6d7 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 2220350 local import (awaelchli)
- cfce27e Merge remote-tracking branch 'origin/lite/accelerators' into lite/acc… (awaelchli)
- 4ba5809 undo device stats :( (awaelchli)
- 231d8c3 fix import (awaelchli)
- 6e1f03a stupid typehints (awaelchli)
- 1505eb4 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 334e3cf Merge branch 'master' into lite/accelerators (Borda)
- e832e67 more refactors :( (awaelchli)
- a90ef22 Merge remote-tracking branch 'origin/lite/accelerators' into lite/acc… (awaelchli)
- 8bf889b [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 1195cec fix (awaelchli)
- f4dd9a5 Merge branch 'master' into lite/accelerators3 (awaelchli)
- c1f029e rename init_device to setup_device (awaelchli)
- 4cc08fe remove unused import (awaelchli)
- 9b8572d make uppercase to differentiate from class (awaelchli)
- 06bf069 trick test after moving import locally (awaelchli)
lightning_lite/accelerators/__init__.py (new file):

```python
# Copyright The PyTorch Lightning team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_lite.accelerators.accelerator import Accelerator  # noqa: F401
from lightning_lite.accelerators.cpu import CPUAccelerator  # noqa: F401
from lightning_lite.accelerators.cuda import CUDAAccelerator  # noqa: F401
from lightning_lite.accelerators.mps import MPSAccelerator  # noqa: F401
from lightning_lite.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
from lightning_lite.accelerators.tpu import TPUAccelerator  # noqa: F401

_ACCELERATORS_BASE_MODULE = "lightning_lite.accelerators"
AcceleratorRegistry = _AcceleratorRegistry()
call_register_accelerators(AcceleratorRegistry, _ACCELERATORS_BASE_MODULE)
```
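The module-level registry above follows a self-registration pattern: each accelerator class exposes a `register_accelerators` hook, and `call_register_accelerators` walks the classes in the base module and invokes that hook on each one. A minimal, self-contained sketch of the idea (the `MiniRegistry` and `ToyCPUAccelerator` names are hypothetical, not part of the PR):

```python
from typing import Type


class MiniRegistry(dict):
    """Toy stand-in for _AcceleratorRegistry: maps a name to a class plus metadata."""

    def register(self, name: str, cls: Type, description: str = "") -> None:
        # Store the class and its description under the given name.
        self[name] = {"cls": cls, "description": description}

    def get_cls(self, name: str) -> Type:
        return self[name]["cls"]


class ToyCPUAccelerator:
    @classmethod
    def register_accelerators(cls, registry: MiniRegistry) -> None:
        # Each accelerator registers itself under a short string key.
        registry.register("cpu", cls, description=cls.__name__)


registry = MiniRegistry()
# call_register_accelerators effectively does this for every accelerator class:
ToyCPUAccelerator.register_accelerators(registry)
print(sorted(registry))  # ['cpu']
```

The benefit of the design is that adding a new accelerator module requires no edits to the registry itself; the new class only needs to implement the hook.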
lightning_lite/accelerators/accelerator.py (new file):

```python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict

import torch


class Accelerator(ABC):
    """The Accelerator base class.

    An Accelerator is meant to deal with one type of hardware.
    """

    @abstractmethod
    def init_device(self, device: torch.device) -> None:
        """Create and prepare the device for the current process."""

    @abstractmethod
    def teardown(self) -> None:
        """Clean up any state created by the accelerator."""

    @staticmethod
    @abstractmethod
    def parse_devices(devices: Any) -> Any:
        """Accelerator device parsing logic."""

    @staticmethod
    @abstractmethod
    def get_parallel_devices(devices: Any) -> Any:
        """Gets parallel devices for the Accelerator."""

    @staticmethod
    @abstractmethod
    def auto_device_count() -> int:
        """Get the device count when set to auto."""

    @staticmethod
    @abstractmethod
    def is_available() -> bool:
        """Detect if the hardware is available."""

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        pass
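A caller typically drives these hooks in a fixed order: check availability, normalize the user's device request, expand it to concrete devices, then set up and later tear down the chosen device. A runnable sketch of that lifecycle under stated assumptions: `ToyAccelerator`/`ToyCPU` are hypothetical stand-ins, and `torch.device` is replaced by plain strings so the example has no torch dependency.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class ToyAccelerator(ABC):
    """Mirrors the Accelerator contract with string devices instead of torch.device."""

    @abstractmethod
    def init_device(self, device: str) -> None: ...

    @abstractmethod
    def teardown(self) -> None: ...

    @staticmethod
    @abstractmethod
    def parse_devices(devices: Any) -> Any: ...

    @staticmethod
    @abstractmethod
    def get_parallel_devices(devices: Any) -> List[str]: ...

    @staticmethod
    @abstractmethod
    def is_available() -> bool: ...


class ToyCPU(ToyAccelerator):
    def init_device(self, device: str) -> None:
        if device != "cpu":
            raise ValueError(f"Device should be CPU, got {device} instead.")

    def teardown(self) -> None:
        pass  # nothing to clean up for CPU

    @staticmethod
    def parse_devices(devices: Any) -> int:
        return int(devices)  # e.g. "2" -> 2

    @staticmethod
    def get_parallel_devices(devices: Any) -> List[str]:
        return ["cpu"] * int(devices)

    @staticmethod
    def is_available() -> bool:
        return True


# Typical driver sequence a connector would follow:
acc = ToyCPU()
assert acc.is_available()
count = acc.parse_devices("2")
devices = acc.get_parallel_devices(count)
acc.init_device(devices[0])
acc.teardown()
print(devices)  # ['cpu', 'cpu']
```

Because every hook except `register_accelerators` is abstract, forgetting to implement one makes the subclass uninstantiable, which surfaces the error at construction time rather than mid-run.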
lightning_lite/accelerators/cpu.py (new file):

```python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Union

import torch

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities import device_parser


class CPUAccelerator(Accelerator):
    """Accelerator for CPU devices."""

    def init_device(self, device: torch.device) -> None:
        """
        Raises:
            ValueError:
                If the selected device is not CPU.
        """
        if device.type != "cpu":
            raise ValueError(f"Device should be CPU, got {device} instead.")

    def teardown(self) -> None:
        pass

    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> int:
        """Accelerator device parsing logic."""
        devices = device_parser.parse_cpu_cores(devices)
        return devices

    @staticmethod
    def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
        """Gets parallel devices for the Accelerator."""
        devices = device_parser.parse_cpu_cores(devices)
        return [torch.device("cpu")] * devices

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return 1

    @staticmethod
    def is_available() -> bool:
        """CPU is always available for execution."""
        return True

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "cpu",
            cls,
            description=cls.__class__.__name__,
        )
```
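Both `parse_devices` and `get_parallel_devices` funnel through `device_parser.parse_cpu_cores`, which normalizes the user-facing `devices` flag to a plain core count. A simplified, self-contained approximation of that normalization (the `parse_cpu_cores_sketch` name is hypothetical, and the real helper performs additional validation):

```python
from typing import List, Union


def parse_cpu_cores_sketch(devices: Union[int, str, List[int]]) -> int:
    """Simplified stand-in for device_parser.parse_cpu_cores."""
    if isinstance(devices, str):
        devices = int(devices)  # e.g. "3" -> 3
    if isinstance(devices, list):
        devices = len(devices)  # a list of indices counts as N processes
    if not isinstance(devices, int) or devices < 1:
        raise TypeError(f"Expected a positive number of CPU cores, got {devices!r}")
    return devices


print(parse_cpu_cores_sketch("4"))  # 4
print(parse_cpu_cores_sketch(2))    # 2
```

`get_parallel_devices` then returns that many copies of `torch.device("cpu")`, since CPU processes all target the same logical device.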
lightning_lite/accelerators/cuda.py (new file):

```python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union

import torch

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities import device_parser


class CUDAAccelerator(Accelerator):
    """Accelerator for NVIDIA CUDA devices."""

    def init_device(self, device: torch.device) -> None:
        """
        Raises:
            ValueError:
                If the selected device is not of type CUDA.
        """
        if device.type != "cuda":
            raise ValueError(f"Device should be CUDA, got {device} instead.")
        torch.cuda.set_device(device)

    def teardown(self) -> None:
        # clean up memory
        torch.cuda.empty_cache()

    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
        """Accelerator device parsing logic."""
        return device_parser.parse_gpu_ids(devices, include_cuda=True)

    @staticmethod
    def get_parallel_devices(devices: List[int]) -> List[torch.device]:
        """Gets parallel devices for the Accelerator."""
        return [torch.device("cuda", i) for i in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return device_parser.num_cuda_devices()

    @staticmethod
    def is_available() -> bool:
        return device_parser.num_cuda_devices() > 0

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "cuda",
            cls,
            description=cls.__class__.__name__,
        )
```
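Unlike the CPU case, `CUDAAccelerator.get_parallel_devices` receives a list of GPU indices (already normalized by `parse_gpu_ids`) and wraps each index in its own `torch.device`. The mapping is a one-liner; this torch-free sketch shows the shape of the output, with `"cuda:i"` strings standing in for `torch.device("cuda", i)` and the function name being hypothetical:

```python
from typing import List


def get_parallel_devices_sketch(devices: List[int]) -> List[str]:
    """Mirrors CUDAAccelerator.get_parallel_devices with string devices."""
    return [f"cuda:{i}" for i in devices]


# parse_gpu_ids would normalize a request such as devices=2 to [0, 1],
# or an explicit selection like [0, 2] passes through unchanged:
print(get_parallel_devices_sketch([0, 2]))  # ['cuda:0', 'cuda:2']
```

Each returned device carries its own index, so downstream strategies can pin one process per GPU.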
lightning_lite/accelerators/mps.py (new file):

```python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from functools import lru_cache
from typing import Dict, List, Optional, Union

import torch

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities import device_parser
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12


class MPSAccelerator(Accelerator):
    """Accelerator for Metal Apple Silicon GPU devices."""

    def init_device(self, device: torch.device) -> None:
        """
        Raises:
            ValueError:
                If the selected device is not MPS.
        """
        if device.type != "mps":
            raise ValueError(f"Device should be MPS, got {device} instead.")

    def teardown(self) -> None:
        pass

    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
        """Accelerator device parsing logic."""
        parsed_devices = device_parser.parse_gpu_ids(devices, include_mps=True)
        return parsed_devices

    @staticmethod
    def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
        """Gets parallel devices for the Accelerator."""
        parsed_devices = MPSAccelerator.parse_devices(devices)
        assert parsed_devices is not None
        return [torch.device("mps", i) for i in range(len(parsed_devices))]

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return 1

    @staticmethod
    @lru_cache
    def is_available() -> bool:
        """MPS is only available for certain torch builds starting at torch>=1.12, and is only enabled on a machine
        with the ARM-based Apple Silicon processors."""
        return (
            _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")
        )

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "mps",
            cls,
            description=cls.__class__.__name__,
        )
```
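Note the `@lru_cache` on `is_available`: the MPS check combines a version test, a backend probe, and a `platform.processor()` call, so caching makes repeated availability queries pay that cost only once per process. A small self-contained illustration of the memoization effect (`is_available_sketch` and the call counter are hypothetical):

```python
from functools import lru_cache

calls = {"n": 0}


@lru_cache
def is_available_sketch() -> bool:
    """Stand-in for an expensive hardware probe; the body runs at most once."""
    calls["n"] += 1  # count how many times the probe actually executes
    return True


is_available_sketch()
is_available_sketch()
print(calls["n"])  # 1
```

A caveat of this pattern: the result is frozen for the lifetime of the process, which is fine for hardware that cannot appear mid-run.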
lightning_lite/accelerators/tpu.py (new file):

```python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union

import torch

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities import device_parser
from lightning_lite.utilities.imports import _TPU_AVAILABLE


class TPUAccelerator(Accelerator):
    """Accelerator for TPU devices."""

    def init_device(self, device: torch.device) -> None:
        pass

    def teardown(self) -> None:
        pass

    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]:
        """Accelerator device parsing logic."""
        return device_parser.parse_tpu_cores(devices)

    @staticmethod
    def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]:
        """Gets parallel devices for the Accelerator."""
        if isinstance(devices, int):
            return list(range(devices))
        return devices

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return 8

    @staticmethod
    def is_available() -> bool:
        return _TPU_AVAILABLE

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "tpu",
            cls,
            description=cls.__class__.__name__,
        )
```
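`TPUAccelerator.get_parallel_devices` is the one implementation that returns plain integers rather than `torch.device` objects: an integer request expands to core indices `0..N-1`, while an explicit list of core indices passes through untouched. The same logic, extracted into a standalone function (the `tpu_parallel_devices_sketch` name is hypothetical):

```python
from typing import List, Union


def tpu_parallel_devices_sketch(devices: Union[int, List[int]]) -> List[int]:
    """Mirrors TPUAccelerator.get_parallel_devices."""
    if isinstance(devices, int):
        # An int N means "use N cores", expanded to indices 0..N-1.
        return list(range(devices))
    # An explicit list of core indices is already in final form.
    return devices


print(tpu_parallel_devices_sketch(8))       # [0, 1, 2, 3, 4, 5, 6, 7]
print(tpu_parallel_devices_sketch([1, 5]))  # [1, 5]
```

The default of 8 in `auto_device_count` matches the core count of a single TPU v2/v3 board, which explains why `auto` expands to eight parallel devices.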