Skip to content

Commit 8e5647d

Browse files
committed
Apply review's suggestions
1 parent 57d83b3 commit 8e5647d

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

pytorch_lightning/utilities/device_dtype_mixin.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def dtype(self) -> Union[str, torch.dtype]:
3131
return self._dtype
3232

3333
@dtype.setter
34-
def dtype(self, new_dtype: Union[str, torch.dtype]) -> RuntimeError:
34+
def dtype(self, new_dtype: Union[str, torch.dtype]) -> None:
3535
# necessary to avoid infinite recursion
3636
raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).')
3737

@@ -121,12 +121,9 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> 'DeviceDtyp
121121
Returns:
122122
Module: self
123123
"""
124-
if isinstance(device, torch.device):
125-
property_device = device
126-
elif isinstance(device, int):
127-
property_device = torch.device('cuda', index=device)
128-
else:
129-
property_device = torch.device('cuda')
124+
property_device = (
125+
device if isinstance(device, torch.device) else torch.device('cuda', index=device) # type: ignore
126+
) # mypy expects `device` for `index` to be int, while `Optional[int]` is okay => ignore typing for now
130127
self.__update_properties(device=property_device)
131128
return super().cuda(device=device)
132129

0 commit comments

Comments
 (0)