@@ -275,6 +275,13 @@ def __set_name__(self, enum_class, member_name):
275
275
enum_member .__objclass__ = enum_class
276
276
enum_member .__init__ (* args )
277
277
enum_member ._sort_order_ = len (enum_class ._member_names_ )
278
+
279
+ if Flag is not None and issubclass (enum_class , Flag ):
280
+ enum_class ._flag_mask_ |= value
281
+ if _is_single_bit (value ):
282
+ enum_class ._singles_mask_ |= value
283
+ enum_class ._all_bits_ = 2 ** ((enum_class ._flag_mask_ ).bit_length ()) - 1
284
+
278
285
# If another member with the same value was already defined, the
279
286
# new member becomes an alias to the existing one.
280
287
try :
@@ -532,12 +539,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
532
539
classdict ['_use_args_' ] = use_args
533
540
#
534
541
# convert future enum members into temporary _proto_members
535
- # and record integer values in case this will be a Flag
536
- flag_mask = 0
537
542
for name in member_names :
538
543
value = classdict [name ]
539
- if isinstance (value , int ):
540
- flag_mask |= value
541
544
classdict [name ] = _proto_member (value )
542
545
#
543
546
# house-keeping structures
@@ -554,8 +557,9 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
554
557
boundary
555
558
or getattr (first_enum , '_boundary_' , None )
556
559
)
557
- classdict ['_flag_mask_' ] = flag_mask
558
- classdict ['_all_bits_' ] = 2 ** ((flag_mask ).bit_length ()) - 1
560
+ classdict ['_flag_mask_' ] = 0
561
+ classdict ['_singles_mask_' ] = 0
562
+ classdict ['_all_bits_' ] = 0
559
563
classdict ['_inverted_' ] = None
560
564
try :
561
565
exc = None
@@ -644,21 +648,10 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
644
648
):
645
649
delattr (enum_class , '_boundary_' )
646
650
delattr (enum_class , '_flag_mask_' )
651
+ delattr (enum_class , '_singles_mask_' )
647
652
delattr (enum_class , '_all_bits_' )
648
653
delattr (enum_class , '_inverted_' )
649
654
elif Flag is not None and issubclass (enum_class , Flag ):
650
- # ensure _all_bits_ is correct and there are no missing flags
651
- single_bit_total = 0
652
- multi_bit_total = 0
653
- for flag in enum_class ._member_map_ .values ():
654
- flag_value = flag ._value_
655
- if _is_single_bit (flag_value ):
656
- single_bit_total |= flag_value
657
- else :
658
- # multi-bit flags are considered aliases
659
- multi_bit_total |= flag_value
660
- enum_class ._flag_mask_ = single_bit_total
661
- #
662
655
# set correct __iter__
663
656
member_list = [m ._value_ for m in enum_class ]
664
657
if member_list != sorted (member_list ):
@@ -974,6 +967,7 @@ def _find_data_repr_(mcls, class_name, bases):
974
967
975
968
@classmethod
976
969
def _find_data_type_ (mcls , class_name , bases ):
970
+ # a datatype has a __new__ method, or a __dataclass_fields__ attribute
977
971
data_types = set ()
978
972
base_chain = set ()
979
973
for chain in bases :
@@ -986,7 +980,7 @@ def _find_data_type_(mcls, class_name, bases):
986
980
if base ._member_type_ is not object :
987
981
data_types .add (base ._member_type_ )
988
982
break
989
- elif '__new__' in base .__dict__ or '__init__ ' in base .__dict__ :
983
+ elif '__new__' in base .__dict__ or '__dataclass_fields__ ' in base .__dict__ :
990
984
data_types .add (candidate or base )
991
985
break
992
986
else :
@@ -1303,8 +1297,8 @@ def _reduce_ex_by_global_name(self, proto):
1303
1297
class FlagBoundary (StrEnum ):
1304
1298
"""
1305
1299
control how out of range values are handled
1306
- "strict" -> error is raised
1307
- "conform" -> extra bits are discarded [default for Flag]
1300
+ "strict" -> error is raised [default for Flag]
1301
+ "conform" -> extra bits are discarded
1308
1302
"eject" -> lose flag status
1309
1303
"keep" -> keep flag status and all bits [default for IntFlag]
1310
1304
"""
@@ -1315,7 +1309,7 @@ class FlagBoundary(StrEnum):
1315
1309
STRICT , CONFORM , EJECT , KEEP = FlagBoundary
1316
1310
1317
1311
1318
- class Flag (Enum , boundary = CONFORM ):
1312
+ class Flag (Enum , boundary = STRICT ):
1319
1313
"""
1320
1314
Support for flags
1321
1315
"""
@@ -1394,6 +1388,7 @@ def _missing_(cls, value):
1394
1388
# - value must not include any skipped flags (e.g. if bit 2 is not
1395
1389
# defined, then 0d10 is invalid)
1396
1390
flag_mask = cls ._flag_mask_
1391
+ singles_mask = cls ._singles_mask_
1397
1392
all_bits = cls ._all_bits_
1398
1393
neg_value = None
1399
1394
if (
@@ -1425,7 +1420,8 @@ def _missing_(cls, value):
1425
1420
value = all_bits + 1 + value
1426
1421
# get members and unknown
1427
1422
unknown = value & ~ flag_mask
1428
- member_value = value & flag_mask
1423
+ aliases = value & ~ singles_mask
1424
+ member_value = value & singles_mask
1429
1425
if unknown and cls ._boundary_ is not KEEP :
1430
1426
raise ValueError (
1431
1427
'%s(%r) --> unknown values %r [%s]'
@@ -1439,11 +1435,25 @@ def _missing_(cls, value):
1439
1435
pseudo_member = cls ._member_type_ .__new__ (cls , value )
1440
1436
if not hasattr (pseudo_member , '_value_' ):
1441
1437
pseudo_member ._value_ = value
1442
- if member_value :
1443
- pseudo_member ._name_ = '|' .join ([
1444
- m ._name_ for m in cls ._iter_member_ (member_value )
1445
- ])
1446
- if unknown :
1438
+ if member_value or aliases :
1439
+ members = []
1440
+ combined_value = 0
1441
+ for m in cls ._iter_member_ (member_value ):
1442
+ members .append (m )
1443
+ combined_value |= m ._value_
1444
+ if aliases :
1445
+ value = member_value | aliases
1446
+ for n , pm in cls ._member_map_ .items ():
1447
+ if pm not in members and pm ._value_ and pm ._value_ & value == pm ._value_ :
1448
+ members .append (pm )
1449
+ combined_value |= pm ._value_
1450
+ unknown = value ^ combined_value
1451
+ pseudo_member ._name_ = '|' .join ([m ._name_ for m in members ])
1452
+ if not combined_value :
1453
+ pseudo_member ._name_ = None
1454
+ elif unknown and cls ._boundary_ is STRICT :
1455
+ raise ValueError ('%r: no members with value %r' % (cls , unknown ))
1456
+ elif unknown :
1447
1457
pseudo_member ._name_ += '|%s' % cls ._numeric_repr_ (unknown )
1448
1458
else :
1449
1459
pseudo_member ._name_ = None
@@ -1675,6 +1685,7 @@ def convert_class(cls):
1675
1685
body ['_boundary_' ] = boundary or etype ._boundary_
1676
1686
body ['_flag_mask_' ] = None
1677
1687
body ['_all_bits_' ] = None
1688
+ body ['_singles_mask_' ] = None
1678
1689
body ['_inverted_' ] = None
1679
1690
body ['__or__' ] = Flag .__or__
1680
1691
body ['__xor__' ] = Flag .__xor__
@@ -1750,7 +1761,8 @@ def convert_class(cls):
1750
1761
else :
1751
1762
multi_bits |= value
1752
1763
gnv_last_values .append (value )
1753
- enum_class ._flag_mask_ = single_bits
1764
+ enum_class ._flag_mask_ = single_bits | multi_bits
1765
+ enum_class ._singles_mask_ = single_bits
1754
1766
enum_class ._all_bits_ = 2 ** ((single_bits | multi_bits ).bit_length ()) - 1
1755
1767
# set correct __iter__
1756
1768
member_list = [m ._value_ for m in enum_class ]
0 commit comments