Skip to content

Commit a610e04

Browse files
authored
Add typing for utilities/enums.py (#11298)
1 parent e9009d6 commit a610e04

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ module = [
8989
"pytorch_lightning.trainer.callback_hook",
9090
"pytorch_lightning.trainer.connectors.accelerator_connector",
9191
"pytorch_lightning.trainer.connectors.callback_connector",
92-
"pytorch_lightning.trainer.connectors.checkpoint_connector",
9392
"pytorch_lightning.trainer.connectors.data_connector",
9493
"pytorch_lightning.trainer.data_loading",
9594
"pytorch_lightning.trainer.optimizers",
@@ -102,7 +101,6 @@ module = [
102101
"pytorch_lightning.utilities.data",
103102
"pytorch_lightning.utilities.deepspeed",
104103
"pytorch_lightning.utilities.distributed",
105-
"pytorch_lightning.utilities.enums",
106104
"pytorch_lightning.utilities.fetching",
107105
"pytorch_lightning.utilities.memory",
108106
"pytorch_lightning.utilities.meta",

pytorch_lightning/utilities/enums.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Enumerated utilities."""
15+
from __future__ import annotations
16+
1517
import os
1618
from enum import Enum, EnumMeta
17-
from typing import Any, List, Optional, Union
19+
from typing import Any
1820

1921
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2022
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
@@ -24,14 +26,14 @@ class LightningEnum(str, Enum):
2426
"""Type of any enumerator with allowed comparison to string invariant to cases."""
2527

2628
@classmethod
27-
def from_str(cls, value: str) -> Optional["LightningEnum"]:
29+
def from_str(cls, value: str) -> LightningEnum | None:
2830
statuses = [status for status in dir(cls) if not status.startswith("_")]
2931
for st in statuses:
3032
if st.lower() == value.lower():
3133
return getattr(cls, st)
3234
return None
3335

34-
def __eq__(self, other: Union[str, Enum]) -> bool:
36+
def __eq__(self, other: object) -> bool:
3537
other = other.value if isinstance(other, Enum) else str(other)
3638
return self.value.lower() == other.lower()
3739

@@ -55,12 +57,12 @@ def __getattribute__(cls, name: str) -> Any:
5557
return obj
5658

5759
def __getitem__(cls, name: str) -> Any:
58-
member = super().__getitem__(name)
60+
member: _OnAccessEnumMeta = super().__getitem__(name)
5961
member.deprecate()
6062
return member
6163

62-
def __call__(cls, value: str, *args: Any, **kwargs: Any) -> Any:
63-
obj = super().__call__(value, *args, **kwargs)
64+
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
65+
obj = super().__call__(*args, **kwargs)
6466
if isinstance(obj, Enum):
6567
obj.deprecate()
6668
return obj
@@ -94,11 +96,11 @@ class PrecisionType(LightningEnum):
9496
MIXED = "mixed"
9597

9698
@staticmethod
97-
def supported_type(precision: Union[str, int]) -> bool:
99+
def supported_type(precision: str | int) -> bool:
98100
return any(x == precision for x in PrecisionType)
99101

100102
@staticmethod
101-
def supported_types() -> List[str]:
103+
def supported_types() -> list[str]:
102104
return [x.value for x in PrecisionType]
103105

104106

@@ -123,7 +125,7 @@ class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta):
123125
DDP_FULLY_SHARDED = "ddp_fully_sharded"
124126

125127
@staticmethod
126-
def interactive_compatible_types() -> List["DistributedType"]:
128+
def interactive_compatible_types() -> list[DistributedType]:
127129
"""Returns a list containing interactive compatible DistributeTypes."""
128130
return [
129131
DistributedType.DP,
@@ -181,7 +183,7 @@ def supported_type(val: str) -> bool:
181183
return any(x.value == val for x in GradClipAlgorithmType)
182184

183185
@staticmethod
184-
def supported_types() -> List[str]:
186+
def supported_types() -> list[str]:
185187
return [x.value for x in GradClipAlgorithmType]
186188

187189

@@ -219,7 +221,7 @@ def get_max_depth(mode: str) -> int:
219221
raise ValueError(f"`mode` can be {', '.join(list(ModelSummaryMode))}, got {mode}.")
220222

221223
@staticmethod
222-
def supported_types() -> List[str]:
224+
def supported_types() -> list[str]:
223225
return [x.value for x in ModelSummaryMode]
224226

225227

@@ -247,7 +249,7 @@ class _StrategyType(LightningEnum):
247249
DDP_FULLY_SHARDED = "ddp_fully_sharded"
248250

249251
@staticmethod
250-
def interactive_compatible_types() -> List["_StrategyType"]:
252+
def interactive_compatible_types() -> list[_StrategyType]:
251253
"""Returns a list containing interactive compatible _StrategyTypes."""
252254
return [
253255
_StrategyType.DP,
@@ -299,7 +301,7 @@ def is_manual(self) -> bool:
299301
return self is _FaultTolerantMode.MANUAL
300302

301303
@classmethod
302-
def detect_current_mode(cls) -> "_FaultTolerantMode":
304+
def detect_current_mode(cls) -> _FaultTolerantMode:
303305
"""This classmethod detects if `Fault Tolerant` is activated and maps its value to `_FaultTolerantMode`."""
304306
env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0").lower()
305307
# the int values are kept for backwards compatibility, but long-term we want to keep only the strings

0 commit comments

Comments
 (0)