@@ -112,7 +112,8 @@ def __init__(
         self._accelerator_type = None
 
         self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
-        self.accelerator = accelerator
+        # TODO: Rename this to something else once all the distributed flags are moved to strategy
+        self.distributed_backend = accelerator
 
         self._init_deterministic(deterministic)
 
@@ -202,7 +203,7 @@ def _init_deterministic(self, deterministic: bool) -> None:
             os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
     def select_accelerator_type(self) -> None:
-        if self.accelerator == "auto":
+        if self.distributed_backend == "auto":
             if self.has_tpu:
                 self._accelerator_type = DeviceType.TPU
             elif self.has_ipu:
@@ -212,34 +213,34 @@ def select_accelerator_type(self) -> None:
             else:
                 self._set_devices_to_cpu_num_processes()
                 self._accelerator_type = DeviceType.CPU
-        elif self.accelerator == DeviceType.TPU:
+        elif self.distributed_backend == DeviceType.TPU:
             if not self.has_tpu:
                 msg = "TPUs are not available" if not _TPU_AVAILABLE else "you didn't pass `tpu_cores` to `Trainer`"
                 raise MisconfigurationException(f"You passed `accelerator='tpu'`, but {msg}.")
             self._accelerator_type = DeviceType.TPU
-        elif self.accelerator == DeviceType.IPU:
+        elif self.distributed_backend == DeviceType.IPU:
             if not self.has_ipu:
                 msg = "IPUs are not available" if not _IPU_AVAILABLE else "you didn't pass `ipus` to `Trainer`"
                 raise MisconfigurationException(f"You passed `accelerator='ipu'`, but {msg}.")
             self._accelerator_type = DeviceType.IPU
-        elif self.accelerator == DeviceType.GPU:
+        elif self.distributed_backend == DeviceType.GPU:
             if not self.has_gpu:
                 msg = "you didn't pass `gpus` to `Trainer`" if torch.cuda.is_available() else "GPUs are not available"
                 raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}.")
             self._accelerator_type = DeviceType.GPU
-        elif self.accelerator == DeviceType.CPU:
+        elif self.distributed_backend == DeviceType.CPU:
             self._set_devices_to_cpu_num_processes()
             self._accelerator_type = DeviceType.CPU
 
-        if self.accelerator in self.accelerator_types:
-            self.accelerator = None
+        if self.distributed_backend in self.accelerator_types:
+            self.distributed_backend = None
 
     def _validate_accelerator_and_devices(self) -> None:
-        if self.accelerator not in self.accelerator_types and self.devices is not None:
+        if self.distributed_backend not in self.accelerator_types and self.devices is not None:
             raise MisconfigurationException(
                 f"You passed `devices={self.devices}` but haven't specified"
                 " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping,"
-                f" got `accelerator={self.accelerator!r}`."
+                f" got `accelerator={self.distributed_backend!r}`."
             )
 
     def _validate_accelerator_type(self) -> None:
@@ -255,16 +256,16 @@ def _warn_if_devices_flag_ignored(self) -> None:
         if self.devices is None:
             return
         devices_warning = f"The flag `devices={self.devices}` will be ignored, as you have set"
-        if self.accelerator in ("auto", DeviceType.TPU):
+        if self.distributed_backend in ("auto", DeviceType.TPU):
             if self.tpu_cores is not None:
                 rank_zero_warn(f"{devices_warning} `tpu_cores={self.tpu_cores}`")
-        elif self.accelerator in ("auto", DeviceType.IPU):
+        elif self.distributed_backend in ("auto", DeviceType.IPU):
             if self.ipus is not None:
                 rank_zero_warn(f"{devices_warning} `ipus={self.ipus}`")
-        elif self.accelerator in ("auto", DeviceType.GPU):
+        elif self.distributed_backend in ("auto", DeviceType.GPU):
             if self.gpus is not None:
                 rank_zero_warn(f"{devices_warning} `gpus={self.gpus}`")
-        elif self.accelerator in ("auto", DeviceType.CPU):
+        elif self.distributed_backend in ("auto", DeviceType.CPU):
             if self.num_processes != 1:
                 rank_zero_warn(f"{devices_warning} `num_processes={self.num_processes}`")
 
@@ -281,15 +282,15 @@ def _set_devices_if_none(self) -> None:
             self.devices = self.num_processes
 
     def _handle_accelerator_and_strategy(self) -> None:
-        if self.accelerator is not None and self.accelerator in list(DistributedType):
+        if self.distributed_backend is not None and self.distributed_backend in list(DistributedType):
             rank_zero_deprecation(
-                f"Passing `Trainer(accelerator={self.accelerator!r})` has been deprecated"
-                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={self.accelerator!r})` instead."
+                f"Passing `Trainer(accelerator={self.distributed_backend!r})` has been deprecated"
+                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={self.distributed_backend!r})` instead."
             )
             if self.strategy is not None:
                 raise MisconfigurationException(
                     f"You have passed `Trainer(strategy={self.strategy!r})` but have"
-                    f" also passed `Trainer(accelerator={self.accelerator!r})`."
+                    f" also passed `Trainer(accelerator={self.distributed_backend!r})`."
                     f" HINT: Use just `Trainer(strategy={self.strategy!r})` instead."
                 )
 
@@ -635,8 +636,11 @@ def select_precision_plugin(self) -> PrecisionPlugin:
             return ApexMixedPrecisionPlugin(self.amp_level)
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
-        if isinstance(self.accelerator, Accelerator) and self.accelerator.training_type_plugin is not None:
-            plugin = self.accelerator.training_type_plugin
+        if (
+            isinstance(self.distributed_backend, Accelerator)
+            and self.distributed_backend.training_type_plugin is not None
+        ):
+            plugin = self.distributed_backend.training_type_plugin
         elif self.use_ddp2:
             plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment)
         elif self.use_ddp and self.use_deepspeed:
@@ -718,15 +722,15 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra
         return training_type
 
     def select_accelerator(self) -> Accelerator:
-        if isinstance(self.accelerator, Accelerator):
+        if isinstance(self.distributed_backend, Accelerator):
             # custom accelerator from user
             if self._precision_plugin is not None or self._training_type_plugin is not None:
                 # plugins also specified by user
                 rank_zero_warn(
                     "Specified `Precision` and `TrainingType` plugins will be ignored,"
                     " since an `Accelerator` instance was provided."
                 )
-            return self.accelerator
+            return self.distributed_backend
 
         if self.use_gpu:
             acc_cls = GPUAccelerator
@@ -766,32 +770,32 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
             return
 
         if strategy is not None and strategy in TrainingTypePluginsRegistry:
-            self.accelerator = TrainingTypePluginsRegistry[strategy]["distributed_backend"]
+            self.distributed_backend = TrainingTypePluginsRegistry[strategy]["distributed_backend"]
         elif strategy is not None:
-            self.accelerator = strategy
+            self.distributed_backend = strategy
 
-        if isinstance(self.accelerator, Accelerator):
+        if isinstance(self.distributed_backend, Accelerator):
             return
 
         is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == DeviceType.CPU
-        _use_cpu = is_cpu_accelerator_type or self.accelerator and "cpu" in self.accelerator
+        _use_cpu = is_cpu_accelerator_type or self.distributed_backend and "cpu" in self.distributed_backend
 
-        if self.accelerator is None:
+        if self.distributed_backend is None:
             if self.has_horovodrun():
                 self._set_horovod_backend()
             elif self.num_gpus == 0 and self.num_nodes > 1:
                 self._distrib_type = DistributedType.DDP
             elif self.num_gpus == 0 and self.num_processes > 1:
-                self.accelerator = DistributedType.DDP_SPAWN
+                self.distributed_backend = DistributedType.DDP_SPAWN
             elif self.num_gpus > 1 and not _use_cpu:
                 rank_zero_warn(
                     "You requested multiple GPUs but did not specify a backend, e.g."
                     ' `Trainer(strategy="dp"|"ddp"|"ddp2")`. Setting `strategy="ddp_spawn"` for you.'
                 )
-                self.accelerator = DistributedType.DDP_SPAWN
+                self.distributed_backend = DistributedType.DDP_SPAWN
 
         # special case with DDP on CPUs
-        if self.accelerator == DistributedType.DDP_CPU:
+        if self.distributed_backend == DistributedType.DDP_CPU:
             if _TPU_AVAILABLE:
                 raise MisconfigurationException(
                     "`accelerator='ddp_cpu'` is not supported on TPU machines. "
@@ -816,8 +820,8 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
             self._distrib_type = DistributedType.TPU_SPAWN
         elif self.has_ipu and not _use_cpu:
             self._device_type = DeviceType.IPU
-        elif self.accelerator and self._distrib_type is None:
-            self._distrib_type = DistributedType(self.accelerator)
+        elif self.distributed_backend and self._distrib_type is None:
+            self._distrib_type = DistributedType(self.distributed_backend)
 
         if self.num_gpus > 0 and not _use_cpu:
             self._device_type = DeviceType.GPU
@@ -850,7 +854,7 @@ def set_distributed_mode(self, strategy: Optional[str] = None):
             self.num_processes = self.num_nodes
 
         # Horovod is an extra case...
-        if self.accelerator == DistributedType.HOROVOD:
+        if self.distributed_backend == DistributedType.HOROVOD:
            self._set_horovod_backend()
 
         using_valid_distributed = self.use_ddp or self.use_ddp2
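
A minimal user-facing sketch of the migration this rename prepares for, assuming the PyTorch Lightning 1.5 `Trainer` API described by the deprecation and error messages above (the exact flag values are illustrative):

    from pytorch_lightning import Trainer

    # Deprecated in v1.5, to be removed in v1.7: passing a distributed mode via `accelerator`
    trainer = Trainer(accelerator="ddp_spawn", gpus=2)

    # Preferred: distributed mode via `strategy`, device type via `accelerator` + `devices`
    trainer = Trainer(strategy="ddp_spawn", accelerator="gpu", devices=2)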