50
50
SaveConfigCallback ,
51
51
)
52
52
from pytorch_lightning .utilities .exceptions import MisconfigurationException
53
- from pytorch_lightning .utilities .imports import _TORCHVISION_AVAILABLE
53
+ from pytorch_lightning .utilities .imports import _TORCH_GREATER_EQUAL_1_8 , _TORCHVISION_AVAILABLE
54
54
from tests .helpers import BoringDataModule , BoringModel
55
55
from tests .helpers .runif import RunIf
56
56
from tests .helpers .utils import no_warning_call
@@ -576,21 +576,17 @@ def on_fit_start(self):
576
576
raise MisconfigurationException ("Error on fit start" )
577
577
578
578
579
+ @RunIf (skip_windows = True )
579
580
@pytest .mark .parametrize ("logger" , (False , True ))
580
- @pytest .mark .parametrize (
581
- "trainer_kwargs" ,
582
- (
583
- # dict(strategy="ddp_spawn")
584
- # dict(strategy="ddp")
585
- # the previous accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn
586
- # TODO revisit this test as it never worked with DDP or DDPSpawn
587
- dict (strategy = "single_device" ),
588
- pytest .param ({"tpu_cores" : 1 }, marks = RunIf (tpu = True )),
589
- ),
590
- )
591
- def test_cli_distributed_save_config_callback (tmpdir , logger , trainer_kwargs ):
581
+ @pytest .mark .parametrize ("strategy" , ("ddp_spawn" , "ddp" ))
582
+ def test_cli_distributed_save_config_callback (tmpdir , logger , strategy ):
583
+ if _TORCH_GREATER_EQUAL_1_8 :
584
+ from torch .multiprocessing import ProcessRaisedException
585
+ else :
586
+ ProcessRaisedException = Exception
587
+
592
588
with mock .patch ("sys.argv" , ["any.py" , "fit" ]), pytest .raises (
593
- MisconfigurationException , match = r"Error on fit start"
589
+ ( MisconfigurationException , ProcessRaisedException ) , match = r"Error on fit start"
594
590
):
595
591
LightningCLI (
596
592
EarlyExitTestModel ,
@@ -599,7 +595,9 @@ def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
599
595
"logger" : logger ,
600
596
"max_steps" : 1 ,
601
597
"max_epochs" : 1 ,
602
- ** trainer_kwargs ,
598
+ "strategy" : strategy ,
599
+ "accelerator" : "auto" ,
600
+ "devices" : 1 ,
603
601
},
604
602
)
605
603
if logger :
0 commit comments