12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
"""Enumerated utilities."""
15
+ from __future__ import annotations
16
+
15
17
import os
16
18
from enum import Enum , EnumMeta
17
- from typing import Any , List , Optional , Union
19
+ from typing import Any
18
20
19
21
from pytorch_lightning .utilities .exceptions import MisconfigurationException
20
22
from pytorch_lightning .utilities .warnings import rank_zero_deprecation
@@ -24,14 +26,14 @@ class LightningEnum(str, Enum):
24
26
"""Type of any enumerator with allowed comparison to string invariant to cases."""
25
27
26
28
@classmethod
27
- def from_str (cls , value : str ) -> Optional [ " LightningEnum" ] :
29
+ def from_str (cls , value : str ) -> LightningEnum | None :
28
30
statuses = [status for status in dir (cls ) if not status .startswith ("_" )]
29
31
for st in statuses :
30
32
if st .lower () == value .lower ():
31
33
return getattr (cls , st )
32
34
return None
33
35
34
- def __eq__ (self , other : Union [ str , Enum ] ) -> bool :
36
+ def __eq__ (self , other : object ) -> bool :
35
37
other = other .value if isinstance (other , Enum ) else str (other )
36
38
return self .value .lower () == other .lower ()
37
39
@@ -55,12 +57,12 @@ def __getattribute__(cls, name: str) -> Any:
55
57
return obj
56
58
57
59
def __getitem__ (cls , name : str ) -> Any :
58
- member = super ().__getitem__ (name )
60
+ member : _OnAccessEnumMeta = super ().__getitem__ (name )
59
61
member .deprecate ()
60
62
return member
61
63
62
- def __call__ (cls , value : str , * args : Any , ** kwargs : Any ) -> Any :
63
- obj = super ().__call__ (value , * args , ** kwargs )
64
+ def __call__ (cls , * args : Any , ** kwargs : Any ) -> Any :
65
+ obj = super ().__call__ (* args , ** kwargs )
64
66
if isinstance (obj , Enum ):
65
67
obj .deprecate ()
66
68
return obj
@@ -94,11 +96,11 @@ class PrecisionType(LightningEnum):
94
96
MIXED = "mixed"
95
97
96
98
@staticmethod
97
- def supported_type (precision : Union [ str , int ] ) -> bool :
99
+ def supported_type (precision : str | int ) -> bool :
98
100
return any (x == precision for x in PrecisionType )
99
101
100
102
@staticmethod
101
- def supported_types () -> List [str ]:
103
+ def supported_types () -> list [str ]:
102
104
return [x .value for x in PrecisionType ]
103
105
104
106
@@ -123,7 +125,7 @@ class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta):
123
125
DDP_FULLY_SHARDED = "ddp_fully_sharded"
124
126
125
127
@staticmethod
126
- def interactive_compatible_types () -> List [ " DistributedType" ]:
128
+ def interactive_compatible_types () -> list [ DistributedType ]:
127
129
"""Returns a list containing interactive compatible DistributeTypes."""
128
130
return [
129
131
DistributedType .DP ,
@@ -181,7 +183,7 @@ def supported_type(val: str) -> bool:
181
183
return any (x .value == val for x in GradClipAlgorithmType )
182
184
183
185
@staticmethod
184
- def supported_types () -> List [str ]:
186
+ def supported_types () -> list [str ]:
185
187
return [x .value for x in GradClipAlgorithmType ]
186
188
187
189
@@ -219,7 +221,7 @@ def get_max_depth(mode: str) -> int:
219
221
raise ValueError (f"`mode` can be { ', ' .join (list (ModelSummaryMode ))} , got { mode } ." )
220
222
221
223
@staticmethod
222
- def supported_types () -> List [str ]:
224
+ def supported_types () -> list [str ]:
223
225
return [x .value for x in ModelSummaryMode ]
224
226
225
227
@@ -247,7 +249,7 @@ class _StrategyType(LightningEnum):
247
249
DDP_FULLY_SHARDED = "ddp_fully_sharded"
248
250
249
251
@staticmethod
250
- def interactive_compatible_types () -> List [ " _StrategyType" ]:
252
+ def interactive_compatible_types () -> list [ _StrategyType ]:
251
253
"""Returns a list containing interactive compatible _StrategyTypes."""
252
254
return [
253
255
_StrategyType .DP ,
@@ -299,7 +301,7 @@ def is_manual(self) -> bool:
299
301
return self is _FaultTolerantMode .MANUAL
300
302
301
303
@classmethod
302
- def detect_current_mode (cls ) -> " _FaultTolerantMode" :
304
+ def detect_current_mode (cls ) -> _FaultTolerantMode :
303
305
"""This classmethod detects if `Fault Tolerant` is activated and maps its value to `_FaultTolerantMode`."""
304
306
env_value = os .getenv ("PL_FAULT_TOLERANT_TRAINING" , "0" ).lower ()
305
307
# the int values are kept for backwards compatibility, but long-term we want to keep only the strings
0 commit comments