21
21
from torch .utils .data import DataLoader
22
22
23
23
import tests .helpers .utils as tutils
24
- from pytorch_lightning import Trainer
24
+ from pytorch_lightning import LightningModule , Trainer
25
25
from pytorch_lightning .core .step_result import Result
26
26
from tests .helpers import BoringDataModule , BoringModel
27
27
from tests .helpers .runif import RunIf
@@ -36,24 +36,20 @@ def _setup_ddp(rank, worldsize):
36
36
dist .init_process_group ("gloo" , rank = rank , world_size = worldsize )
37
37
38
38
39
def _ddp_test_fn(rank, worldsize):
    """Per-process DDP worker: verify cross-rank tensor reduction.

    Joins the process group (via the module-level ``_setup_ddp`` helper),
    then checks that LightningModule's private ``__sync`` helper SUM-reduces
    a tensor of ``1.0`` across all ranks — so the reduced value must equal
    the world size.

    Args:
        rank: index of this spawned process within the process group.
        worldsize: total number of processes participating in the group.
    """
    _setup_ddp(rank, worldsize)

    one = torch.tensor([1.0])
    # Name-mangled access to the private LightningModule.__sync helper
    # under test; SUM-reduce the tensor over every rank in the group.
    reduced = LightningModule._LightningModule__sync(
        one,
        sync_dist=True,
        sync_dist_op=torch.distributed.ReduceOp.SUM,
    )
    assert reduced.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
@RunIf(skip_windows=True)
def test_result_reduce_ddp():
    """Make sure result logging works with DDP"""
    # Deterministic seed + a free master port so the gloo group can form.
    tutils.reset_seed()
    tutils.set_random_master_port()

    # Spawn one process per rank; each runs the reduction check above.
    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)
@pytest .mark .parametrize (
0 commit comments