Skip to content

Fix mypy in utilities.device_dtype_mixin #8127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Merged
34 changes: 19 additions & 15 deletions pytorch_lightning/utilities/device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union
from typing import Any, Optional, Union

import torch
from torch.nn import Module
Expand All @@ -21,17 +21,17 @@
class DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ['device', 'dtype']

def __init__(self):
def __init__(self) -> None:
super().__init__()
self._dtype = torch.get_default_dtype()
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
self._device = torch.device('cpu')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does device also need the torch.device type?


@property
def dtype(self) -> Union[str, torch.dtype]:
return self._dtype

@dtype.setter
def dtype(self, new_dtype: Union[str, torch.dtype]):
def dtype(self, new_dtype: Union[str, torch.dtype]) -> None:
# necessary to avoid infinite recursion
raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).')

Expand All @@ -45,7 +45,7 @@ def device(self) -> Union[str, torch.device]:

return device

def to(self, *args, **kwargs) -> Module:
def to(self, *args: Any, **kwargs: Any) -> 'DeviceDtypeModuleMixin':
"""Moves and/or casts the parameters and buffers.

This can be called as
Expand Down Expand Up @@ -108,7 +108,7 @@ def to(self, *args, **kwargs) -> Module:
self.__update_properties(device=out[0], dtype=out[1])
return super().to(*args, **kwargs)

def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module:
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> 'DeviceDtypeModuleMixin':
"""Moves all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
Expand All @@ -121,11 +121,13 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module:
Returns:
Module: self
"""
property_device = device if isinstance(device, torch.device) else torch.device('cuda', index=device)
property_device = (
device if isinstance(device, torch.device) else torch.device('cuda', index=device) # type: ignore
) # mypy expects `device` for `index` to be int, while `Optional[int]` is okay => ignore typing for now
self.__update_properties(device=property_device)
return super().cuda(device=device)

def cpu(self) -> Module:
def cpu(self) -> 'DeviceDtypeModuleMixin':
"""Moves all model parameters and buffers to the CPU.

Returns:
Expand All @@ -134,7 +136,7 @@ def cpu(self) -> Module:
self.__update_properties(device=torch.device('cpu'))
return super().cpu()

def type(self, dst_type: Union[str, torch.dtype]) -> Module:
def type(self, dst_type: Union[str, torch.dtype]) -> 'DeviceDtypeModuleMixin':
"""Casts all parameters and buffers to :attr:`dst_type`.

Arguments:
Expand All @@ -146,16 +148,16 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Module:
self.__update_properties(dtype=dst_type)
return super().type(dst_type=dst_type)

def float(self) -> Module:
"""Casts all floating point parameters and buffers to float datatype.
def float(self) -> 'DeviceDtypeModuleMixin':
"""Casts all floating point parameters and buffers to ``float`` datatype.

Returns:
Module: self
"""
self.__update_properties(dtype=torch.float)
return super().float()

def double(self) -> Module:
def double(self) -> 'DeviceDtypeModuleMixin':
"""Casts all floating point parameters and buffers to ``double`` datatype.

Returns:
Expand All @@ -164,7 +166,7 @@ def double(self) -> Module:
self.__update_properties(dtype=torch.double)
return super().double()

def half(self) -> Module:
def half(self) -> 'DeviceDtypeModuleMixin':
"""Casts all floating point parameters and buffers to ``half`` datatype.

Returns:
Expand All @@ -173,9 +175,11 @@ def half(self) -> Module:
self.__update_properties(dtype=torch.half)
return super().half()

def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
def __update_properties(
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:

def apply_fn(module):
def apply_fn(module: Union['DeviceDtypeModuleMixin', Module]) -> None:
if not isinstance(module, DeviceDtypeModuleMixin):
return
if device is not None:
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ ignore_errors = True
ignore_errors = True
[mypy-pytorch_lightning.utilities.cli]
ignore_errors = False
[mypy-pytorch_lightning.utilities.device_dtype_mixin]
ignore_errors = False
[mypy-pytorch_lightning.utilities.device_parser]
ignore_errors = False
[mypy-pytorch_lightning.utilities.parsing]
Expand Down