@@ -701,7 +701,7 @@ def is_singleton_type(typ: Type) -> bool:
701
701
)
702
702
703
703
704
- def try_expanding_enum_to_union (typ : Type , target_fullname : str ) -> ProperType :
704
+ def try_expanding_sum_type_to_union (typ : Type , target_fullname : str ) -> ProperType :
705
705
"""Attempts to recursively expand any enum Instances with the given target_fullname
706
706
into a Union of all of its component LiteralTypes.
707
707
@@ -723,28 +723,34 @@ class Status(Enum):
723
723
typ = get_proper_type (typ )
724
724
725
725
if isinstance (typ , UnionType ):
726
- items = [try_expanding_enum_to_union (item , target_fullname ) for item in typ .items ]
726
+ items = [try_expanding_sum_type_to_union (item , target_fullname ) for item in typ .items ]
727
727
return make_simplified_union (items , contract_literals = False )
728
- elif isinstance (typ , Instance ) and typ .type .is_enum and typ .type .fullname == target_fullname :
729
- new_items = []
730
- for name , symbol in typ .type .names .items ():
731
- if not isinstance (symbol .node , Var ):
732
- continue
733
- # Skip "_order_" and "__order__", since Enum will remove it
734
- if name in ("_order_" , "__order__" ):
735
- continue
736
- new_items .append (LiteralType (name , typ ))
737
- # SymbolTables are really just dicts, and dicts are guaranteed to preserve
738
- # insertion order only starting with Python 3.7. So, we sort these for older
739
- # versions of Python to help make tests deterministic.
740
- #
741
- # We could probably skip the sort for Python 3.6 since people probably run mypy
742
- # only using CPython, but we might as well for the sake of full correctness.
743
- if sys .version_info < (3 , 7 ):
744
- new_items .sort (key = lambda lit : lit .value )
745
- return make_simplified_union (new_items , contract_literals = False )
746
- else :
747
- return typ
728
+ elif isinstance (typ , Instance ) and typ .type .fullname == target_fullname :
729
+ if typ .type .is_enum :
730
+ new_items = []
731
+ for name , symbol in typ .type .names .items ():
732
+ if not isinstance (symbol .node , Var ):
733
+ continue
734
+ # Skip "_order_" and "__order__", since Enum will remove it
735
+ if name in ("_order_" , "__order__" ):
736
+ continue
737
+ new_items .append (LiteralType (name , typ ))
738
+ # SymbolTables are really just dicts, and dicts are guaranteed to preserve
739
+ # insertion order only starting with Python 3.7. So, we sort these for older
740
+ # versions of Python to help make tests deterministic.
741
+ #
742
+ # We could probably skip the sort for Python 3.6 since people probably run mypy
743
+ # only using CPython, but we might as well for the sake of full correctness.
744
+ if sys .version_info < (3 , 7 ):
745
+ new_items .sort (key = lambda lit : lit .value )
746
+ return make_simplified_union (new_items , contract_literals = False )
747
+ elif typ .type .fullname == "builtins.bool" :
748
+ return make_simplified_union (
749
+ [LiteralType (True , typ ), LiteralType (False , typ )],
750
+ contract_literals = False
751
+ )
752
+
753
+ return typ
748
754
749
755
750
756
def try_contracting_literals_in_union (types : Sequence [Type ]) -> List [ProperType ]:
@@ -762,9 +768,12 @@ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]
762
768
for idx , typ in enumerate (proper_types ):
763
769
if isinstance (typ , LiteralType ):
764
770
fullname = typ .fallback .type .fullname
765
- if typ .fallback .type .is_enum :
771
+ if typ .fallback .type .is_enum or isinstance ( typ . value , bool ) :
766
772
if fullname not in sum_types :
767
- sum_types [fullname ] = (set (get_enum_values (typ .fallback )), [])
773
+ sum_types [fullname ] = (set (get_enum_values (typ .fallback ))
774
+ if typ .fallback .type .is_enum
775
+ else set ((True , False )),
776
+ [])
768
777
literals , indexes = sum_types [fullname ]
769
778
literals .discard (typ .value )
770
779
indexes .append (idx )
0 commit comments