@@ -177,6 +177,7 @@ def __init__(
177
177
self ._training_type_plugin_resolved = False
178
178
self .training_type_plugin = self .final_training_type_plugin ()
179
179
self .accelerator = self .training_type_plugin .accelerator
180
+ self .precision_plugin = self .training_type_plugin ._precision_plugin
180
181
181
182
self ._check_tpu_mis_config ()
182
183
@@ -391,8 +392,7 @@ def handle_given_plugins(self) -> None:
391
392
def accelerator_types (self ) -> List [str ]:
392
393
return ["auto" ] + list (_AcceleratorType )
393
394
394
- @property
395
- def precision_plugin (self ) -> PrecisionPlugin :
395
+ def final_precision_plugin (self ) -> PrecisionPlugin :
396
396
if self ._precision_plugin is None :
397
397
self ._precision_plugin = self .select_precision_plugin ()
398
398
return self ._precision_plugin
@@ -408,16 +408,28 @@ def final_training_type_plugin(self) -> TrainingTypePlugin:
408
408
if self ._checkpoint_io is not None :
409
409
self ._training_type_plugin .checkpoint_io = self ._checkpoint_io
410
410
if (
411
- (hasattr (self .strategy , "precision_plugin" ) and self .precision_plugin is None )
412
- or not hasattr (self .strategy , "precision_plugin" )
411
+ # handle custom strategy with custom precision
412
+ (
413
+ isinstance (self .strategy , TrainingTypePlugin ) and (
414
+ self .strategy .precision_plugin is None
415
+ or not isinstance (self .strategy .precision_plugin , PrecisionPlugin )
416
+ )
417
+ )
418
+ or not isinstance (self .strategy , TrainingTypePlugin )
413
419
):
414
- precision_plugin = self .precision_plugin
420
+ precision_plugin = self .final_precision_plugin ()
415
421
if precision_plugin is not None :
416
422
self ._training_type_plugin ._precision_plugin = precision_plugin
417
423
self ._training_type_plugin_resolved = True
418
424
if (
419
- (hasattr (self .strategy , "accelerator" ) and self .strategy .accelerator is None )
420
- or not hasattr (self .strategy , "accelerator" )
425
+ # handle custom strategy with custom accelerator
426
+ (
427
+ isinstance (self .strategy , TrainingTypePlugin ) and (
428
+ self .strategy .accelerator is None
429
+ or not isinstance (self .strategy .accelerator , Accelerator )
430
+ )
431
+ )
432
+ or not isinstance (self .strategy , TrainingTypePlugin )
421
433
):
422
434
self ._training_type_plugin .accelerator = self .select_accelerator ()
423
435
return self ._training_type_plugin
@@ -790,12 +802,6 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra
790
802
def select_accelerator (self ) -> Accelerator :
791
803
if isinstance (self .distributed_backend , Accelerator ):
792
804
# custom accelerator from user
793
- if self ._precision_plugin is not None or self ._training_type_plugin is not None :
794
- # plugins also specified by user
795
- rank_zero_warn (
796
- "Specified `Precision` and `TrainingType` plugins will be ignored,"
797
- " since an `Accelerator` instance was provided."
798
- )
799
805
return self .distributed_backend
800
806
801
807
if self .use_gpu :
0 commit comments