Skip to content

Commit aacd131

Browse files
authored
Fix mypy in utilities.distributed (#8201)
1 parent efec3d4 commit aacd131

File tree

2 files changed

+28
-19
lines changed

2 files changed

+28
-19
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ module = [
6767
"pytorch_lightning.utilities.cloud_io",
6868
"pytorch_lightning.utilities.device_dtype_mixin",
6969
"pytorch_lightning.utilities.device_parser",
70+
"pytorch_lightning.utilities.distributed",
7071
"pytorch_lightning.utilities.parsing",
7172
]
7273
ignore_errors = "False"

pytorch_lightning/utilities/distributed.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
from functools import wraps
1818
from platform import python_version
19-
from typing import Any, Optional, Union
19+
from typing import Any, Callable, List, Optional, Tuple, Type, Union
2020

2121
import torch
2222
from torch.nn.parallel.distributed import DistributedDataParallel
@@ -31,21 +31,22 @@
3131

3232
else:
3333

34-
class ReduceOp:
34+
class ReduceOp: # type: ignore # (see https://github.com/python/mypy/issues/1153)
3535
SUM = None
3636

37-
class group:
37+
class group: # type: ignore
3838
WORLD = None
3939

4040

4141
log = logging.getLogger(__name__)
4242

4343

44-
def rank_zero_only(fn):
44+
def rank_zero_only(fn: Callable) -> Callable:
4545
@wraps(fn)
46-
def wrapped_fn(*args, **kwargs):
46+
def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
4747
if rank_zero_only.rank == 0:
4848
return fn(*args, **kwargs)
49+
return None
4950

5051
return wrapped_fn
5152

@@ -64,7 +65,7 @@ def _get_rank() -> int:
6465
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())
6566

6667

67-
def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
68+
def rank_zero_warn(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
6869
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn
6970

7071
rank_zero_deprecation(
@@ -74,7 +75,7 @@ def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
7475
return rank_zero_warn(*args, stacklevel=stacklevel, **kwargs)
7576

7677

77-
def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
78+
def rank_zero_deprecation(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
7879
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
7980

8081
rank_zero_deprecation(
@@ -84,29 +85,29 @@ def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
8485
return rank_zero_deprecation(*args, stacklevel=stacklevel, **kwargs)
8586

8687

87-
def _info(*args, stacklevel: int = 2, **kwargs):
88+
def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
8889
if python_version() >= "3.8.0":
8990
kwargs["stacklevel"] = stacklevel
9091
log.info(*args, **kwargs)
9192

9293

93-
def _debug(*args, stacklevel: int = 2, **kwargs):
94+
def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
9495
if python_version() >= "3.8.0":
9596
kwargs["stacklevel"] = stacklevel
9697
log.debug(*args, **kwargs)
9798

9899

99100
@rank_zero_only
100-
def rank_zero_debug(*args, stacklevel: int = 4, **kwargs):
101+
def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
101102
_debug(*args, stacklevel=stacklevel, **kwargs)
102103

103104

104105
@rank_zero_only
105-
def rank_zero_info(*args, stacklevel: int = 4, **kwargs):
106+
def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
106107
_info(*args, stacklevel=stacklevel, **kwargs)
107108

108109

109-
def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
110+
def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
110111
"""
111112
Function to gather all tensors from several ddp processes onto a list that
112113
is broadcasted to all processes
@@ -141,7 +142,7 @@ def distributed_available() -> bool:
141142

142143

143144
def sync_ddp_if_available(
144-
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
145+
result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
145146
) -> torch.Tensor:
146147
"""
147148
Function to reduce a tensor across worker processes during distributed training
@@ -160,7 +161,7 @@ def sync_ddp_if_available(
160161

161162

162163
def sync_ddp(
163-
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
164+
result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
164165
) -> torch.Tensor:
165166
"""
166167
Function to reduce the tensors from several ddp processes to one master process
@@ -196,7 +197,11 @@ def sync_ddp(
196197

197198
class AllGatherGrad(torch.autograd.Function):
198199
@staticmethod
199-
def forward(ctx, tensor, group=group.WORLD):
200+
def forward(
201+
ctx: Any,
202+
tensor: torch.Tensor,
203+
group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
204+
) -> torch.Tensor:
200205
ctx.group = group
201206

202207
gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
@@ -207,7 +212,7 @@ def forward(ctx, tensor, group=group.WORLD):
207212
return gathered_tensor
208213

209214
@staticmethod
210-
def backward(ctx, *grad_output):
215+
def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
211216
grad_output = torch.cat(grad_output)
212217

213218
torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
@@ -216,7 +221,7 @@ def backward(ctx, *grad_output):
216221

217222

218223
def all_gather_ddp_if_available(
219-
tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False
224+
tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
220225
) -> torch.Tensor:
221226
"""
222227
Function to gather a tensor from several distributed processes
@@ -241,8 +246,8 @@ def all_gather_ddp_if_available(
241246
def register_ddp_comm_hook(
242247
model: DistributedDataParallel,
243248
ddp_comm_state: Optional[object] = None,
244-
ddp_comm_hook: Optional[callable] = None,
245-
ddp_comm_wrapper: Optional[callable] = None,
249+
ddp_comm_hook: Optional[Callable] = None,
250+
ddp_comm_wrapper: Optional[Callable] = None,
246251
) -> None:
247252
"""
248253
Function to register communication hook for DDP model
@@ -322,6 +327,9 @@ def register_ddp_comm_hook(
322327
return
323328
if ddp_comm_hook is None:
324329
return
330+
# inform mypy that ddp_comm_hook is callable
331+
ddp_comm_hook: Callable = ddp_comm_hook
332+
325333
if ddp_comm_wrapper is not None:
326334
if not _TORCH_GREATER_EQUAL_1_9:
327335
rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")

0 commit comments

Comments
 (0)