@@ -100,18 +100,18 @@ class _Connector:
100
100
101
101
def __init__ (
102
102
self ,
103
- accelerator : Optional [ Union [str , Accelerator ]] = None ,
104
- strategy : Optional [ Union [str , Strategy ]] = None ,
105
- devices : Optional [ Union [List [int ], str , int ]] = None ,
103
+ accelerator : Union [str , Accelerator ] = "auto" ,
104
+ strategy : Union [str , Strategy ] = "auto" ,
105
+ devices : Union [List [int ], str , int ] = "auto" ,
106
106
num_nodes : int = 1 ,
107
107
precision : _PRECISION_INPUT = "32-true" ,
108
108
plugins : Optional [Union [_PLUGIN_INPUT , List [_PLUGIN_INPUT ]]] = None ,
109
109
) -> None :
110
110
111
111
# These arguments can be set through environment variables set by the CLI
112
- accelerator = self ._argument_from_env ("accelerator" , accelerator , default = None )
113
- strategy = self ._argument_from_env ("strategy" , strategy , default = None )
114
- devices = self ._argument_from_env ("devices" , devices , default = None )
112
+ accelerator = self ._argument_from_env ("accelerator" , accelerator , default = "auto" )
113
+ strategy = self ._argument_from_env ("strategy" , strategy , default = "auto" )
114
+ devices = self ._argument_from_env ("devices" , devices , default = "auto" )
115
115
num_nodes = self ._argument_from_env ("num_nodes" , num_nodes , default = 1 )
116
116
precision = self ._argument_from_env ("precision" , precision , default = "32-true" )
117
117
@@ -123,8 +123,8 @@ def __init__(
123
123
# Raise an exception if there are conflicts between flags
124
124
# Set each valid flag to `self._x_flag` after validation
125
125
# For devices: Assign gpus, etc. to the accelerator flag and devices flag
126
- self ._strategy_flag : Optional [ Union [Strategy , str ]] = None
127
- self ._accelerator_flag : Optional [ Union [Accelerator , str ]] = None
126
+ self ._strategy_flag : Union [Strategy , str ] = "auto"
127
+ self ._accelerator_flag : Union [Accelerator , str ] = "auto"
128
128
self ._precision_input : _PRECISION_INPUT_STR = "32-true"
129
129
self ._precision_instance : Optional [Precision ] = None
130
130
self ._cluster_environment_flag : Optional [Union [ClusterEnvironment , str ]] = None
@@ -141,7 +141,7 @@ def __init__(
141
141
142
142
# 2. Instantiate Accelerator
143
143
# handle `auto`, `None` and `gpu`
144
- if self ._accelerator_flag == "auto" or self . _accelerator_flag is None :
144
+ if self ._accelerator_flag == "auto" :
145
145
self ._accelerator_flag = self ._choose_auto_accelerator ()
146
146
elif self ._accelerator_flag == "gpu" :
147
147
self ._accelerator_flag = self ._choose_gpu_accelerator_backend ()
@@ -152,7 +152,7 @@ def __init__(
152
152
self .cluster_environment : ClusterEnvironment = self ._choose_and_init_cluster_environment ()
153
153
154
154
# 4. Instantiate Strategy - Part 1
155
- if self ._strategy_flag is None :
155
+ if self ._strategy_flag == "auto" :
156
156
self ._strategy_flag = self ._choose_strategy ()
157
157
# In specific cases, ignore user selection and fall back to a different strategy
158
158
self ._check_strategy_and_fallback ()
@@ -166,8 +166,8 @@ def __init__(
166
166
167
167
def _check_config_and_set_final_flags (
168
168
self ,
169
- strategy : Optional [ Union [str , Strategy ] ],
170
- accelerator : Optional [ Union [str , Accelerator ] ],
169
+ strategy : Union [str , Strategy ],
170
+ accelerator : Union [str , Accelerator ],
171
171
precision : _PRECISION_INPUT ,
172
172
plugins : Optional [Union [_PLUGIN_INPUT , List [_PLUGIN_INPUT ]]],
173
173
) -> None :
@@ -188,26 +188,24 @@ def _check_config_and_set_final_flags(
188
188
if isinstance (strategy , str ):
189
189
strategy = strategy .lower ()
190
190
191
- if strategy is not None :
192
- self ._strategy_flag = strategy
191
+ self ._strategy_flag = strategy
193
192
194
- if strategy is not None and strategy not in self ._registered_strategies and not isinstance (strategy , Strategy ):
193
+ if strategy != "auto" and strategy not in self ._registered_strategies and not isinstance (strategy , Strategy ):
195
194
raise ValueError (
196
195
f"You selected an invalid strategy name: `strategy={ strategy !r} `."
197
196
" It must be either a string or an instance of `lightning.fabric.strategies.Strategy`."
198
- " Example choices: ddp, ddp_spawn, deepspeed, dp, ..."
197
+ " Example choices: auto, ddp, ddp_spawn, deepspeed, dp, ..."
199
198
" Find a complete list of options in our documentation at https://lightning.ai"
200
199
)
201
200
202
201
if (
203
- accelerator is not None
204
- and accelerator not in self ._registered_accelerators
202
+ accelerator not in self ._registered_accelerators
205
203
and accelerator not in ("auto" , "gpu" )
206
204
and not isinstance (accelerator , Accelerator )
207
205
):
208
206
raise ValueError (
209
207
f"You selected an invalid accelerator name: `accelerator={ accelerator !r} `."
210
- f" Available names are: { ', ' .join (self ._registered_accelerators )} ."
208
+ f" Available names are: auto, { ', ' .join (self ._registered_accelerators )} ."
211
209
)
212
210
213
211
# MPS accelerator is incompatible with DDP family of strategies. It supports single-device operation only.
@@ -256,9 +254,9 @@ def _check_config_and_set_final_flags(
256
254
# handle the case when the user passes in a strategy instance which has an accelerator, precision,
257
255
# checkpoint io or cluster env set up
258
256
# TODO: improve the error messages below
259
- if self . _strategy_flag and isinstance (self ._strategy_flag , Strategy ):
257
+ if isinstance (self ._strategy_flag , Strategy ):
260
258
if self ._strategy_flag ._accelerator :
261
- if self ._accelerator_flag :
259
+ if self ._accelerator_flag != "auto" :
262
260
raise ValueError ("accelerator set through both strategy class and accelerator flag, choose one" )
263
261
else :
264
262
self ._accelerator_flag = self ._strategy_flag ._accelerator
@@ -297,9 +295,7 @@ def _check_config_and_set_final_flags(
297
295
self ._accelerator_flag = "cuda"
298
296
self ._parallel_devices = self ._strategy_flag .parallel_devices
299
297
300
- def _check_device_config_and_set_final_flags (
301
- self , devices : Optional [Union [List [int ], str , int ]], num_nodes : int
302
- ) -> None :
298
+ def _check_device_config_and_set_final_flags (self , devices : Union [List [int ], str , int ], num_nodes : int ) -> None :
303
299
self ._num_nodes_flag = int (num_nodes ) if num_nodes is not None else 1
304
300
self ._devices_flag = devices
305
301
@@ -314,21 +310,14 @@ def _check_device_config_and_set_final_flags(
314
310
f" using { accelerator_name } accelerator."
315
311
)
316
312
317
- if self ._devices_flag == "auto" and self ._accelerator_flag is None :
318
- raise ValueError (
319
- f"You passed `devices={ devices } ` but haven't specified"
320
- " `accelerator=('auto'|'tpu'|'gpu'|'cpu'|'mps')` for the devices mapping."
321
- )
322
-
323
313
def _choose_auto_accelerator (self ) -> str :
324
314
"""Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
325
- if self ._accelerator_flag == "auto" :
326
- if TPUAccelerator .is_available ():
327
- return "tpu"
328
- if MPSAccelerator .is_available ():
329
- return "mps"
330
- if CUDAAccelerator .is_available ():
331
- return "cuda"
315
+ if TPUAccelerator .is_available ():
316
+ return "tpu"
317
+ if MPSAccelerator .is_available ():
318
+ return "mps"
319
+ if CUDAAccelerator .is_available ():
320
+ return "cuda"
332
321
return "cpu"
333
322
334
323
@staticmethod
@@ -337,7 +326,6 @@ def _choose_gpu_accelerator_backend() -> str:
337
326
return "mps"
338
327
if CUDAAccelerator .is_available ():
339
328
return "cuda"
340
-
341
329
raise RuntimeError ("No supported gpu backend found!" )
342
330
343
331
def _set_parallel_devices_and_init_accelerator (self ) -> None :
@@ -368,7 +356,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
368
356
self ._parallel_devices = accelerator_cls .get_parallel_devices (self ._devices_flag )
369
357
370
358
def _set_devices_flag_if_auto_passed (self ) -> None :
371
- if self ._devices_flag == "auto" or self . _devices_flag is None :
359
+ if self ._devices_flag == "auto" :
372
360
self ._devices_flag = self .accelerator .auto_device_count ()
373
361
374
362
def _choose_and_init_cluster_environment (self ) -> ClusterEnvironment :
@@ -527,7 +515,7 @@ def _lazy_init_strategy(self) -> None:
527
515
raise RuntimeError (
528
516
f"`Fabric(strategy={ self ._strategy_flag !r} )` is not compatible with an interactive"
529
517
" environment. Run your code as a script, or choose one of the compatible strategies:"
530
- f" `Fabric(strategy=None| 'dp'|'ddp_notebook')`."
518
+ f" `Fabric(strategy='dp'|'ddp_notebook')`."
531
519
" In case you are spawning processes yourself, make sure to include the Fabric"
532
520
" creation inside the worker function."
533
521
)
0 commit comments