
Commit d486f94

Fabric: auto default (#16842)
1 parent bc96513 commit d486f94

8 files changed (+147 −69 lines changed)


docs/source-pytorch/fabric/fundamentals/accelerators.rst

Lines changed: 3 additions & 3 deletions
@@ -15,12 +15,12 @@ Fabric enables you to take full advantage of the hardware on your system. It sup
 - GPU (NVIDIA, AMD, Apple Silicon)
 - TPU
 
-By default, Fabric recognizes the accelerator(s) on your system
+By default, Fabric tries to maximize the hardware utilization of your system
 
 .. code-block:: python
 
     # Default settings
-    fabric = Fabric(accelerator="auto", devices="auto")
+    fabric = Fabric(accelerator="auto", devices="auto", strategy="auto")
 
     # Same as
     fabric = Fabric()
@@ -40,7 +40,7 @@ You can also explicitly set which accelerator to use:
     fabric = Fabric(accelerator="gpu", devices=8)
 
     # GPU: Apple M1/M2 only
-    fabric = Fabric(accelerator="mps", devices=8)
+    fabric = Fabric(accelerator="mps")
 
     # GPU: NVIDIA CUDA only
     fabric = Fabric(accelerator="cuda", devices=8)
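Note: with these defaults, constructing `Fabric()` with no arguments now resolves the accelerator, device count, and strategy automatically. A minimal sketch of what that looks like from user code (the printed properties are standard `Fabric` attributes; the actual values depend on the machine):

    from lightning.fabric import Fabric

    # Equivalent to Fabric(accelerator="auto", strategy="auto", devices="auto")
    fabric = Fabric()
    fabric.launch()

    # Inspect what the "auto" defaults resolved to on this machine
    print(fabric.device)      # e.g. cuda:0, mps:0, or cpu
    print(fabric.world_size)  # total number of processes under the chosen strategy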

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Fabric now chooses `accelerator="auto", strategy="auto", devices="auto"` as defaults ([#16842](https://github.com/Lightning-AI/lightning/pull/16842))
+
+
 - Checkpoint saving and loading redesign ([#16434](https://github.com/Lightning-AI/lightning/pull/16434))
   * Changed the method signatrue of `Fabric.save` and `Fabric.load`
   * Changed the method signature of `Strategy.save_checkpoint` and `Fabric.load_checkpoint`

src/lightning/fabric/connector.py

Lines changed: 28 additions & 40 deletions
@@ -100,18 +100,18 @@ class _Connector:
 
     def __init__(
         self,
-        accelerator: Optional[Union[str, Accelerator]] = None,
-        strategy: Optional[Union[str, Strategy]] = None,
-        devices: Optional[Union[List[int], str, int]] = None,
+        accelerator: Union[str, Accelerator] = "auto",
+        strategy: Union[str, Strategy] = "auto",
+        devices: Union[List[int], str, int] = "auto",
         num_nodes: int = 1,
         precision: _PRECISION_INPUT = "32-true",
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
     ) -> None:
 
         # These arguments can be set through environment variables set by the CLI
-        accelerator = self._argument_from_env("accelerator", accelerator, default=None)
-        strategy = self._argument_from_env("strategy", strategy, default=None)
-        devices = self._argument_from_env("devices", devices, default=None)
+        accelerator = self._argument_from_env("accelerator", accelerator, default="auto")
+        strategy = self._argument_from_env("strategy", strategy, default="auto")
+        devices = self._argument_from_env("devices", devices, default="auto")
         num_nodes = self._argument_from_env("num_nodes", num_nodes, default=1)
         precision = self._argument_from_env("precision", precision, default="32-true")
 
@@ -123,8 +123,8 @@ def __init__(
         # Raise an exception if there are conflicts between flags
         # Set each valid flag to `self._x_flag` after validation
         # For devices: Assign gpus, etc. to the accelerator flag and devices flag
-        self._strategy_flag: Optional[Union[Strategy, str]] = None
-        self._accelerator_flag: Optional[Union[Accelerator, str]] = None
+        self._strategy_flag: Union[Strategy, str] = "auto"
+        self._accelerator_flag: Union[Accelerator, str] = "auto"
         self._precision_input: _PRECISION_INPUT_STR = "32-true"
         self._precision_instance: Optional[Precision] = None
         self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
@@ -141,7 +141,7 @@ def __init__(
 
         # 2. Instantiate Accelerator
         # handle `auto`, `None` and `gpu`
-        if self._accelerator_flag == "auto" or self._accelerator_flag is None:
+        if self._accelerator_flag == "auto":
             self._accelerator_flag = self._choose_auto_accelerator()
         elif self._accelerator_flag == "gpu":
             self._accelerator_flag = self._choose_gpu_accelerator_backend()
@@ -152,7 +152,7 @@ def __init__(
         self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()
 
         # 4. Instantiate Strategy - Part 1
-        if self._strategy_flag is None:
+        if self._strategy_flag == "auto":
             self._strategy_flag = self._choose_strategy()
         # In specific cases, ignore user selection and fall back to a different strategy
         self._check_strategy_and_fallback()
@@ -166,8 +166,8 @@ def __init__(
 
     def _check_config_and_set_final_flags(
         self,
-        strategy: Optional[Union[str, Strategy]],
-        accelerator: Optional[Union[str, Accelerator]],
+        strategy: Union[str, Strategy],
+        accelerator: Union[str, Accelerator],
         precision: _PRECISION_INPUT,
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]],
     ) -> None:
@@ -188,26 +188,24 @@ def _check_config_and_set_final_flags(
         if isinstance(strategy, str):
             strategy = strategy.lower()
 
-        if strategy is not None:
-            self._strategy_flag = strategy
+        self._strategy_flag = strategy
 
-        if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
+        if strategy != "auto" and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
             raise ValueError(
                 f"You selected an invalid strategy name: `strategy={strategy!r}`."
                 " It must be either a string or an instance of `lightning.fabric.strategies.Strategy`."
-                " Example choices: ddp, ddp_spawn, deepspeed, dp, ..."
+                " Example choices: auto, ddp, ddp_spawn, deepspeed, dp, ..."
                 " Find a complete list of options in our documentation at https://lightning.ai"
             )
 
         if (
-            accelerator is not None
-            and accelerator not in self._registered_accelerators
+            accelerator not in self._registered_accelerators
             and accelerator not in ("auto", "gpu")
             and not isinstance(accelerator, Accelerator)
         ):
             raise ValueError(
                 f"You selected an invalid accelerator name: `accelerator={accelerator!r}`."
-                f" Available names are: {', '.join(self._registered_accelerators)}."
+                f" Available names are: auto, {', '.join(self._registered_accelerators)}."
             )
 
         # MPS accelerator is incompatible with DDP family of strategies. It supports single-device operation only.
@@ -256,9 +254,9 @@ def _check_config_and_set_final_flags(
         # handle the case when the user passes in a strategy instance which has an accelerator, precision,
         # checkpoint io or cluster env set up
         # TODO: improve the error messages below
-        if self._strategy_flag and isinstance(self._strategy_flag, Strategy):
+        if isinstance(self._strategy_flag, Strategy):
             if self._strategy_flag._accelerator:
-                if self._accelerator_flag:
+                if self._accelerator_flag != "auto":
                     raise ValueError("accelerator set through both strategy class and accelerator flag, choose one")
                 else:
                     self._accelerator_flag = self._strategy_flag._accelerator
@@ -297,9 +295,7 @@ def _check_config_and_set_final_flags(
                 self._accelerator_flag = "cuda"
             self._parallel_devices = self._strategy_flag.parallel_devices
 
-    def _check_device_config_and_set_final_flags(
-        self, devices: Optional[Union[List[int], str, int]], num_nodes: int
-    ) -> None:
+    def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
         self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1
         self._devices_flag = devices
 
@@ -314,21 +310,14 @@ def _check_device_config_and_set_final_flags(
                 f" using {accelerator_name} accelerator."
             )
 
-        if self._devices_flag == "auto" and self._accelerator_flag is None:
-            raise ValueError(
-                f"You passed `devices={devices}` but haven't specified"
-                " `accelerator=('auto'|'tpu'|'gpu'|'cpu'|'mps')` for the devices mapping."
-            )
-
     def _choose_auto_accelerator(self) -> str:
         """Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
-        if self._accelerator_flag == "auto":
-            if TPUAccelerator.is_available():
-                return "tpu"
-            if MPSAccelerator.is_available():
-                return "mps"
-            if CUDAAccelerator.is_available():
-                return "cuda"
+        if TPUAccelerator.is_available():
+            return "tpu"
+        if MPSAccelerator.is_available():
+            return "mps"
+        if CUDAAccelerator.is_available():
+            return "cuda"
         return "cpu"
 
     @staticmethod
@@ -337,7 +326,6 @@ def _choose_gpu_accelerator_backend() -> str:
             return "mps"
         if CUDAAccelerator.is_available():
            return "cuda"
-
         raise RuntimeError("No supported gpu backend found!")
 
     def _set_parallel_devices_and_init_accelerator(self) -> None:
@@ -368,7 +356,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
            self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag)
 
     def _set_devices_flag_if_auto_passed(self) -> None:
-        if self._devices_flag == "auto" or self._devices_flag is None:
+        if self._devices_flag == "auto":
            self._devices_flag = self.accelerator.auto_device_count()
 
    def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
@@ -527,7 +515,7 @@ def _lazy_init_strategy(self) -> None:
            raise RuntimeError(
                f"`Fabric(strategy={self._strategy_flag!r})` is not compatible with an interactive"
                " environment. Run your code as a script, or choose one of the compatible strategies:"
-                f" `Fabric(strategy=None|'dp'|'ddp_notebook')`."
+                f" `Fabric(strategy='dp'|'ddp_notebook')`."
                " In case you are spawning processes yourself, make sure to include the Fabric"
                " creation inside the worker function."
            )
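With `None` no longer a valid sentinel, `_choose_auto_accelerator` reduces to a plain availability cascade: TPU first, then Apple MPS, then CUDA, falling back to CPU. A standalone sketch of that priority order (the boolean parameters stand in for the real `is_available()` checks and are not part of the connector API):

    def choose_auto_accelerator(tpu: bool, mps: bool, cuda: bool) -> str:
        # Priority mirrors the connector: TPU, then MPS, then CUDA, else CPU.
        if tpu:
            return "tpu"
        if mps:
            return "mps"
        if cuda:
            return "cuda"
        return "cpu"

    assert choose_auto_accelerator(False, False, True) == "cuda"
    assert choose_auto_accelerator(False, False, False) == "cpu"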

src/lightning/fabric/fabric.py

Lines changed: 3 additions & 3 deletions
@@ -78,9 +78,9 @@ class Fabric:
 
     def __init__(
         self,
-        accelerator: Optional[Union[str, Accelerator]] = None,
-        strategy: Optional[Union[str, Strategy]] = None,
-        devices: Optional[Union[List[int], str, int]] = None,
+        accelerator: Union[str, Accelerator] = "auto",
+        strategy: Union[str, Strategy] = "auto",
+        devices: Union[List[int], str, int] = "auto",
         num_nodes: int = 1,
         precision: _PRECISION_INPUT = "32-true",
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
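Since all three arguments now default to `"auto"`, a caller can pin just one of them and leave the rest to the connector. A small illustrative example (what the remaining `"auto"` values resolve to depends on the hardware found at runtime):

    from lightning.fabric import Fabric

    # Pin only the accelerator; devices and strategy are still resolved automatically.
    fabric = Fabric(accelerator="cpu")

    # Passing "auto" explicitly is equivalent to omitting the argument.
    also_default = Fabric(accelerator="cpu", strategy="auto", devices="auto")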

tests/tests_fabric/conftest.py

Lines changed: 15 additions & 6 deletions
@@ -75,17 +75,26 @@ def reset_deterministic_algorithm():
     torch.use_deterministic_algorithms(False)
 
 
+def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
+    monkeypatch.setattr(lightning.fabric.accelerators.tpu, "_XLA_AVAILABLE", value)
+    monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
+    monkeypatch.setattr(lightning.fabric.strategies.xla, "_XLA_AVAILABLE", value)
+    monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
+
+
 @pytest.fixture(scope="function")
 def xla_available(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setattr(lightning.fabric.accelerators.tpu, "_XLA_AVAILABLE", True)
-    monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", True)
-    monkeypatch.setattr(lightning.fabric.strategies.xla, "_XLA_AVAILABLE", True)
-    monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", True)
+    mock_xla_available(monkeypatch)
+
+
+def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
+    mock_xla_available(monkeypatch, value)
+    monkeypatch.setattr(lightning.fabric.accelerators.tpu.TPUAccelerator, "is_available", lambda: value)
 
 
 @pytest.fixture(scope="function")
-def tpu_available(xla_available, monkeypatch) -> None:
-    monkeypatch.setattr(lightning.fabric.accelerators.tpu.TPUAccelerator, "is_available", lambda: True)
+def tpu_available(monkeypatch: pytest.MonkeyPatch) -> None:
+    mock_tpu_available(monkeypatch)
 
 
 @pytest.fixture
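The fixture refactor moves the monkeypatching into plain `mock_xla_available` / `mock_tpu_available` helpers with a `value` flag, so tests can also simulate the hardware being absent. A hedged sketch of such a test (the test name is illustrative, and the import assumes the helpers are reachable from the conftest module):

    from tests_fabric.conftest import mock_tpu_available  # hypothetical import path
    from lightning.fabric.accelerators import TPUAccelerator


    def test_behaviour_without_tpu(monkeypatch):
        # Explicitly report XLA/TPU as unavailable, which the old always-True fixtures could not express.
        mock_tpu_available(monkeypatch, value=False)
        assert not TPUAccelerator.is_available()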

tests/tests_fabric/plugins/precision/test_amp_integration.py

Lines changed: 3 additions & 1 deletion
@@ -68,7 +68,9 @@ def after_backward(self, model):
     ],
 )
 def test_amp(accelerator, precision, expected_dtype):
-    fabric = MixedPrecisionBoringFabric(accelerator=accelerator, precision=precision)
+    # TODO: devices>1 fails with:
+    # DDP expects same model across all ranks, but Rank 0 has 2 params, while rank 1 has inconsistent 1 params
+    fabric = MixedPrecisionBoringFabric(accelerator=accelerator, precision=precision, devices=1)
     fabric.expected_dtype = expected_dtype
     fabric.run()
 
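The `devices=1` pin is a consequence of the new defaults: with `devices="auto"`, a multi-GPU machine would otherwise launch one process per visible GPU under an auto-selected DDP strategy, which this test is not written for. An illustrative call showing the same pin in user code (hypothetical, not part of the diff):

    from lightning.fabric import Fabric

    # Force a single-device run even when devices="auto" would pick up every GPU.
    fabric = Fabric(accelerator="cuda", devices=1)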
