@@ -5009,6 +5009,44 @@ def conditional_callable_type_map(
5009
5009
5010
5010
return None , {}
5011
5011
5012
+ def contains_operator_right_operand_type_map (
5013
+ self , item_type : Type , collection_type : Type
5014
+ ) -> tuple [Type , Type ]:
5015
+ """
5016
+ Deduces the type of the right operand of the `in` operator.
5017
+ For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s).
5018
+ """
5019
+ if_types , else_types = [collection_type ], [collection_type ]
5020
+ item_strs = try_getting_str_literals_from_type (item_type )
5021
+ if item_strs :
5022
+ if_types , else_types = self ._contains_string_right_operand_type_map (
5023
+ set (item_strs ), collection_type
5024
+ )
5025
+ return UnionType .make_union (if_types ), UnionType .make_union (else_types )
5026
+
5027
+ def _contains_string_right_operand_type_map (
5028
+ self , item_strs : set [str ], t : Type
5029
+ ) -> tuple [list [Type ], list [Type ]]:
5030
+ t = get_proper_type (t )
5031
+ if_types : list [Type ] = []
5032
+ else_types : list [Type ] = []
5033
+ if isinstance (t , TypedDictType ):
5034
+ if item_strs <= t .items .keys ():
5035
+ if_types .append (t )
5036
+ elif item_strs .isdisjoint (t .items .keys ()):
5037
+ else_types .append (t )
5038
+ else :
5039
+ if_types .append (t )
5040
+ else_types .append (t )
5041
+ elif isinstance (t , UnionType ):
5042
+ for union_item in t .items :
5043
+ a , b = self ._contains_string_right_operand_type_map (item_strs , union_item )
5044
+ if_types .extend (a )
5045
+ else_types .extend (b )
5046
+ else :
5047
+ if_types = else_types = [t ]
5048
+ return if_types , else_types
5049
+
5012
5050
def _is_truthy_type (self , t : ProperType ) -> bool :
5013
5051
return (
5014
5052
(
@@ -5316,28 +5354,39 @@ def has_no_custom_eq_checks(t: Type) -> bool:
5316
5354
elif operator in {"in" , "not in" }:
5317
5355
assert len (expr_indices ) == 2
5318
5356
left_index , right_index = expr_indices
5319
- if left_index not in narrowable_operand_index_to_hash :
5320
- continue
5321
-
5322
5357
item_type = operand_types [left_index ]
5323
5358
collection_type = operand_types [right_index ]
5324
5359
5325
- # We only try and narrow away 'None' for now
5326
- if not is_optional (item_type ):
5327
- continue
5360
+ if_map , else_map = {}, {}
5361
+
5362
+ if left_index in narrowable_operand_index_to_hash :
5363
+ # We only try and narrow away 'None' for now
5364
+ if is_optional (item_type ):
5365
+ collection_item_type = get_proper_type (
5366
+ builtin_item_type (collection_type )
5367
+ )
5368
+ if (
5369
+ collection_item_type is not None
5370
+ and not is_optional (collection_item_type )
5371
+ and not (
5372
+ isinstance (collection_item_type , Instance )
5373
+ and collection_item_type .type .fullname == "builtins.object"
5374
+ )
5375
+ and is_overlapping_erased_types (item_type , collection_item_type )
5376
+ ):
5377
+ if_map [operands [left_index ]] = remove_optional (item_type )
5378
+
5379
+ if right_index in narrowable_operand_index_to_hash :
5380
+ (
5381
+ right_if_type ,
5382
+ right_else_type ,
5383
+ ) = self .contains_operator_right_operand_type_map (
5384
+ item_type , collection_type
5385
+ )
5386
+ expr = operands [right_index ]
5387
+ if_map [expr ] = right_if_type
5388
+ else_map [expr ] = right_else_type
5328
5389
5329
- collection_item_type = get_proper_type (builtin_item_type (collection_type ))
5330
- if collection_item_type is None or is_optional (collection_item_type ):
5331
- continue
5332
- if (
5333
- isinstance (collection_item_type , Instance )
5334
- and collection_item_type .type .fullname == "builtins.object"
5335
- ):
5336
- continue
5337
- if is_overlapping_erased_types (item_type , collection_item_type ):
5338
- if_map , else_map = {operands [left_index ]: remove_optional (item_type )}, {}
5339
- else :
5340
- continue
5341
5390
else :
5342
5391
if_map = {}
5343
5392
else_map = {}
@@ -5390,6 +5439,10 @@ def has_no_custom_eq_checks(t: Type) -> bool:
5390
5439
or_conditional_maps (left_if_vars , right_if_vars ),
5391
5440
and_conditional_maps (left_else_vars , right_else_vars ),
5392
5441
)
5442
+ elif isinstance (node , OpExpr ) and node .op == "in" :
5443
+ left_if_vars , left_else_vars = self .find_isinstance_check (node .left )
5444
+ right_if_vars , right_else_vars = self .find_isinstance_check (node .right )
5445
+
5393
5446
elif isinstance (node , UnaryExpr ) and node .op == "not" :
5394
5447
left , right = self .find_isinstance_check (node .expr )
5395
5448
return right , left
0 commit comments