@@ -16,7 +16,7 @@
 import os
 from functools import wraps
 from platform import python_version
-from typing import Any, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Type, Union

 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -31,21 +31,22 @@

 else:

-    class ReduceOp:
+    class ReduceOp:  # type: ignore # (see https://github.com/python/mypy/issues/1153)
         SUM = None

-    class group:
+    class group:  # type: ignore
         WORLD = None


 log = logging.getLogger(__name__)


-def rank_zero_only(fn):
+def rank_zero_only(fn: Callable) -> Callable:
     @wraps(fn)
-    def wrapped_fn(*args, **kwargs):
+    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
         if rank_zero_only.rank == 0:
             return fn(*args, **kwargs)
+        return None

     return wrapped_fn

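
Aside, not part of the patch: a minimal sketch of how the now-typed decorator behaves. `print_banner` is a made-up function; `rank_zero_only.rank` is resolved from the environment a few lines further down, and the added `return None` is what the new `Optional[Any]` return annotation describes for non-zero ranks.

    from pytorch_lightning.utilities.distributed import rank_zero_only

    @rank_zero_only
    def print_banner(msg: str) -> None:
        # the body runs only on the process whose global rank is 0;
        # on every other rank the wrapper short-circuits and returns None
        print(f"=== {msg} ===")

    print_banner("starting training")  # emitted once, not world_size times
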
@@ -64,7 +65,7 @@ def _get_rank() -> int:
 rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())


-def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
+def rank_zero_warn(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
     from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn

     rank_zero_deprecation(
@@ -74,7 +75,7 @@ def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
     return rank_zero_warn(*args, stacklevel=stacklevel, **kwargs)


-def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
+def rank_zero_deprecation(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
     from pytorch_lightning.utilities.warnings import rank_zero_deprecation

     rank_zero_deprecation(
@@ -84,29 +85,29 @@ def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
     return rank_zero_deprecation(*args, stacklevel=stacklevel, **kwargs)


-def _info(*args, stacklevel: int = 2, **kwargs):
+def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
     if python_version() >= "3.8.0":
         kwargs["stacklevel"] = stacklevel
     log.info(*args, **kwargs)


-def _debug(*args, stacklevel: int = 2, **kwargs):
+def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
     if python_version() >= "3.8.0":
         kwargs["stacklevel"] = stacklevel
     log.debug(*args, **kwargs)


 @rank_zero_only
-def rank_zero_debug(*args, stacklevel: int = 4, **kwargs):
+def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
     _debug(*args, stacklevel=stacklevel, **kwargs)


 @rank_zero_only
-def rank_zero_info(*args, stacklevel: int = 4, **kwargs):
+def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
     _info(*args, stacklevel=stacklevel, **kwargs)


-def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
+def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
     """
     Function to gather all tensors from several ddp processes onto a list that
     is broadcasted to all processes
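
Aside, not part of the patch: a rough usage sketch for `gather_all_tensors`, illustrating why the new return annotation is `List[torch.Tensor]` rather than a single tensor. It assumes a process group has already been initialized (e.g. a two-process DDP run).

    import torch
    import torch.distributed as dist
    from pytorch_lightning.utilities.distributed import gather_all_tensors

    # assumes torch.distributed.init_process_group(...) has already run
    local = torch.tensor([float(dist.get_rank())])
    gathered = gather_all_tensors(local)
    # with world_size == 2, every rank now holds [tensor([0.]), tensor([1.])]
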
@@ -141,7 +142,7 @@ def distributed_available() -> bool:


 def sync_ddp_if_available(
-    result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
 ) -> torch.Tensor:
     """
     Function to reduce a tensor across worker processes during distributed training
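
Aside, not part of the patch: `sync_ddp_if_available` is the guarded entry point; outside an initialized process group it hands the tensor back untouched, otherwise it all-reduces it with the requested op. The one-element `Union[torch.Tensor]` was redundant, hence the plain `torch.Tensor` annotation here and in `sync_ddp` below. A hedged sketch:

    import torch
    from pytorch_lightning.utilities.distributed import sync_ddp_if_available

    loss = torch.tensor(1.0)
    # a no-op in a single-process run; an all-reduce across ranks under DDP
    # (a string op such as "sum" is permitted by the Union[ReduceOp, str] annotation)
    synced = sync_ddp_if_available(loss, reduce_op="sum")
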
@@ -160,7 +161,7 @@ def sync_ddp_if_available(


 def sync_ddp(
-    result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
 ) -> torch.Tensor:
     """
     Function to reduce the tensors from several ddp processes to one master process
@@ -196,7 +197,11 @@ def sync_ddp(

 class AllGatherGrad(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, tensor, group=group.WORLD):
+    def forward(
+        ctx: Any,
+        tensor: torch.Tensor,
+        group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
+    ) -> torch.Tensor:
         ctx.group = group

         gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
@@ -207,7 +212,7 @@ def forward(ctx, tensor, group=group.WORLD):
         return gathered_tensor

     @staticmethod
-    def backward(ctx, *grad_output):
+    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
         grad_output = torch.cat(grad_output)

         torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
@@ -216,7 +221,7 @@ def backward(ctx, *grad_output):


 def all_gather_ddp_if_available(
-    tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False
+    tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
 ) -> torch.Tensor:
     """
     Function to gather a tensor from several distributed processes
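
Aside, not part of the patch: `all_gather_ddp_if_available` routes through the `AllGatherGrad` autograd Function above, and with `sync_grads=True` the gathered result stays connected to the autograd graph (the `backward` above all-reduces the incoming gradients). A minimal sketch, assuming an initialized two-rank process group:

    import torch
    from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available

    pred = torch.randn(4, 10, requires_grad=True)
    gathered = all_gather_ddp_if_available(pred, sync_grads=True)
    # the per-rank tensors are gathered into one tensor; with sync_grads=True,
    # backpropagating through `gathered` sends gradients back to `pred`
    gathered.sum().backward()
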
@@ -241,8 +246,8 @@ def all_gather_ddp_if_available(
 def register_ddp_comm_hook(
     model: DistributedDataParallel,
     ddp_comm_state: Optional[object] = None,
-    ddp_comm_hook: Optional[callable] = None,
-    ddp_comm_wrapper: Optional[callable] = None,
+    ddp_comm_hook: Optional[Callable] = None,
+    ddp_comm_wrapper: Optional[Callable] = None,
 ) -> None:
     """
     Function to register communication hook for DDP model
@@ -322,6 +327,9 @@ def register_ddp_comm_hook(
         return
     if ddp_comm_hook is None:
         return
+    # inform mypy that ddp_comm_hook is callable
+    ddp_comm_hook: Callable = ddp_comm_hook
+
     if ddp_comm_wrapper is not None:
         if not _TORCH_GREATER_EQUAL_1_9:
             rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
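
Aside, not part of the patch: a sketch of how a hook might be passed now that the parameters are typed as `Optional[Callable]`. `ddp_model` stands in for an already-constructed `DistributedDataParallel` module, and the fp16 compression hook ships with torch>=1.8.

    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
    from pytorch_lightning.utilities.distributed import register_ddp_comm_hook

    # `ddp_model` is assumed to be an existing DistributedDataParallel instance
    register_ddp_comm_hook(
        model=ddp_model,
        ddp_comm_hook=default.fp16_compress_hook,  # casts gradients to fp16 for the all-reduce
    )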