Skip to content

Commit 3b511bd

Browse files
authored
Add flag to prohibit equality checks between non-overlapping checks (#6370)
Fixes #1271 The new per-file flag `--strict-equality` disables equality, identity, and container checks between non-overlapping types. In general the implementation is straightforward. Here are some corner cases I have found: * `Any` should be always safe. * Type promotions should be ignored, `b'abc == 'abc'` should be an error. * Checks like `if x is not None: ...`, are special cased to not be errors if `x` has non-optional type. * `Optional[str]` and `Optional[bytes]` should be considered non-overlapping for the purpose of this flag. * For structural containers and custom `__eq__()` (i.e. incompatible with `object.__eq__()`) I suppress the non-overlapping types error, if there is already an error originating from the method call check. Note that I updated `typing-full.pyi` so that `Sequence` inherits from `Container` (this is needed by some added tests, and also matches real stubs). This however caused necessary changes in a bunch of builtins fixtures.
1 parent b8c78e6 commit 3b511bd

18 files changed

+346
-13
lines changed

docs/source/command_line.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,23 @@ of the above sections.
386386
# 'items' now has type List[List[str]]
387387
...
388388
389+
``--strict-equality``
390+
By default, mypy allows always-false comparisons like ``42 == 'no'``.
391+
Use this flag to prohibit such comparisons of non-overlapping types, and
392+
similar identity and container checks:
393+
394+
.. code-block:: python
395+
396+
from typing import Text
397+
398+
text: Text
399+
if b'some bytes' in text: # Error: non-overlapping check!
400+
...
401+
if text != b'other bytes': # Error: non-overlapping check!
402+
...
403+
404+
assert text is not None # OK, this special case is allowed.
405+
389406
``--strict``
390407
This flag mode enables all optional error checking flags. You can see the
391408
list of flags enabled by strict mode in the full ``mypy --help`` output.

docs/source/config_file.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,10 @@ Miscellaneous strictness flags
294294
Allows variables to be redefined with an arbitrary type, as long as the redefinition
295295
is in the same block and nesting level as the original definition.
296296

297+
``strict_equality`` (bool, default False)
298+
Prohibit equality checks, identity checks, and container checks between
299+
non-overlapping types.
300+
297301
Global-only options
298302
*******************
299303

mypy/checker.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2997,6 +2997,25 @@ def analyze_iterable_item_type(self, expr: Expression) -> Tuple[Type, Type]:
29972997
nextmethod = 'next'
29982998
return iterator, echk.check_method_call_by_name(nextmethod, iterator, [], [], expr)[0]
29992999

3000+
def analyze_container_item_type(self, typ: Type) -> Optional[Type]:
3001+
"""Check if a type is a nominal container of a union of such.
3002+
3003+
Return the corresponding container item type.
3004+
"""
3005+
if isinstance(typ, UnionType):
3006+
types = [] # type: List[Type]
3007+
for item in typ.items:
3008+
c_type = self.analyze_container_item_type(item)
3009+
if c_type:
3010+
types.append(c_type)
3011+
return UnionType.make_union(types)
3012+
if isinstance(typ, Instance) and typ.type.has_base('typing.Container'):
3013+
supertype = self.named_type('typing.Container').type
3014+
super_instance = map_instance_to_supertype(typ, supertype)
3015+
assert len(super_instance.args) == 1
3016+
return super_instance.args[0]
3017+
return None
3018+
30003019
def analyze_index_variables(self, index: Expression, item_type: Type,
30013020
infer_lvalue_type: bool, context: Context) -> None:
30023021
"""Type check or infer for loop or list comprehension index vars."""

mypy/checkexpr.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from mypy import message_registry
4545
from mypy.infer import infer_type_arguments, infer_function_type_arguments
4646
from mypy import join
47-
from mypy.meet import narrow_declared_type
47+
from mypy.meet import narrow_declared_type, is_overlapping_types
4848
from mypy.subtypes import (
4949
is_subtype, is_proper_subtype, is_equivalent, find_member, non_method_protocol_members,
5050
)
@@ -1914,6 +1914,11 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
19141914
_, method_type = self.check_method_call_by_name(
19151915
'__contains__', right_type, [left], [ARG_POS], e, local_errors)
19161916
sub_result = self.bool_type()
1917+
# Container item type for strict type overlap checks. Note: we need to only
1918+
# check for nominal type, because a usual "Unsupported operands for in"
1919+
# will be reported for types incompatible with __contains__().
1920+
# See testCustomContainsCheckStrictEquality for an example.
1921+
cont_type = self.chk.analyze_container_item_type(right_type)
19171922
if isinstance(right_type, PartialType):
19181923
# We don't really know if this is an error or not, so just shut up.
19191924
pass
@@ -1929,16 +1934,29 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
19291934
self.named_type('builtins.function'))
19301935
if not is_subtype(left_type, itertype):
19311936
self.msg.unsupported_operand_types('in', left_type, right_type, e)
1937+
# Only show dangerous overlap if there are no other errors.
1938+
elif (not local_errors.is_errors() and cont_type and
1939+
self.dangerous_comparison(left_type, cont_type)):
1940+
self.msg.dangerous_comparison(left_type, cont_type, 'container', e)
19321941
else:
19331942
self.msg.add_errors(local_errors)
19341943
elif operator in nodes.op_methods:
19351944
method = self.get_operator_method(operator)
1945+
err_count = self.msg.errors.total_errors()
19361946
sub_result, method_type = self.check_op(method, left_type, right, e,
1937-
allow_reverse=True)
1947+
allow_reverse=True)
1948+
# Only show dangerous overlap if there are no other errors. See
1949+
# testCustomEqCheckStrictEquality for an example.
1950+
if self.msg.errors.total_errors() == err_count and operator in ('==', '!='):
1951+
right_type = self.accept(right)
1952+
if self.dangerous_comparison(left_type, right_type):
1953+
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)
19381954

19391955
elif operator == 'is' or operator == 'is not':
1940-
self.accept(right) # validate the right operand
1956+
right_type = self.accept(right) # validate the right operand
19411957
sub_result = self.bool_type()
1958+
if self.dangerous_comparison(left_type, right_type):
1959+
self.msg.dangerous_comparison(left_type, right_type, 'identity', e)
19421960
method_type = None
19431961
else:
19441962
raise RuntimeError('Unknown comparison operator {}'.format(operator))
@@ -1954,6 +1972,30 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
19541972
assert result is not None
19551973
return result
19561974

1975+
def dangerous_comparison(self, left: Type, right: Type) -> bool:
1976+
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.
1977+
1978+
Rules:
1979+
* X and None are overlapping even in strict-optional mode. This is to allow
1980+
'assert x is not None' for x defined as 'x = None # type: str' in class body
1981+
(otherwise mypy itself would have couple dozen errors because of this).
1982+
* Optional[X] and Optional[Y] are non-overlapping if X and Y are
1983+
non-overlapping, although technically None is overlap, it is most
1984+
likely an error.
1985+
* Any overlaps with everything, i.e. always safe.
1986+
* Promotions are ignored, so both 'abc' == b'abc' and 1 == 1.0
1987+
are errors. This is mostly needed for bytes vs unicode, and
1988+
int vs float are added just for consistency.
1989+
"""
1990+
if not self.chk.options.strict_equality:
1991+
return False
1992+
if isinstance(left, NoneTyp) or isinstance(right, NoneTyp):
1993+
return False
1994+
if isinstance(left, UnionType) and isinstance(right, UnionType):
1995+
left = remove_optional(left)
1996+
right = remove_optional(right)
1997+
return not is_overlapping_types(left, right, ignore_promotions=True)
1998+
19571999
def get_operator_method(self, op: str) -> str:
19582000
if op == '/' and self.chk.options.python_version[0] == 2:
19592001
# TODO also check for "from __future__ import division"

mypy/errors.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ def copy(self) -> 'Errors':
169169
new.scope = self.scope
170170
return new
171171

172+
def total_errors(self) -> int:
173+
return sum(len(errs) for errs in self.error_info_map.values())
174+
172175
def set_ignore_prefix(self, prefix: str) -> None:
173176
"""Set path prefix that will be removed from all paths."""
174177
prefix = os.path.normpath(prefix)

mypy/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,11 @@ def add_invertible_flag(flag: str,
527527
help="Allow unconditional variable redefinition with a new type",
528528
group=strictness_group)
529529

530+
add_invertible_flag('--strict-equality', default=False, strict_flag=False,
531+
help="Prohibit equality, identity, and container checks for"
532+
" non-overlapping types",
533+
group=strictness_group)
534+
530535
incremental_group = parser.add_argument_group(
531536
title='Incremental mode',
532537
description="Adjust how mypy incrementally type checks and caches modules. "

mypy/meet.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,28 @@ def is_none_typevar_overlap(t1: Type, t2: Type) -> bool:
221221
# As before, we degrade into 'Instance' whenever possible.
222222

223223
if isinstance(left, TypeType) and isinstance(right, TypeType):
224-
# TODO: Can Callable[[...], T] and Type[T] be partially overlapping?
225224
return _is_overlapping_types(left.item, right.item)
226225

226+
def _type_object_overlap(left: Type, right: Type) -> bool:
227+
"""Special cases for type object types overlaps."""
228+
# TODO: these checks are a bit in gray area, adjust if they cause problems.
229+
# 1. Type[C] vs Callable[..., C], where the latter is class object.
230+
if isinstance(left, TypeType) and isinstance(right, CallableType) and right.is_type_obj():
231+
return _is_overlapping_types(left.item, right.ret_type)
232+
# 2. Type[C] vs Meta, where Meta is a metaclass for C.
233+
if (isinstance(left, TypeType) and isinstance(left.item, Instance) and
234+
isinstance(right, Instance)):
235+
left_meta = left.item.type.metaclass_type
236+
if left_meta is not None:
237+
return _is_overlapping_types(left_meta, right)
238+
# builtins.type (default metaclass) overlaps with all metaclasses
239+
return right.type.has_base('builtins.type')
240+
# 3. Callable[..., C] vs Meta is considered below, when we switch to fallbacks.
241+
return False
242+
243+
if isinstance(left, TypeType) or isinstance(right, TypeType):
244+
return _type_object_overlap(left, right) or _type_object_overlap(right, left)
245+
227246
if isinstance(left, CallableType) and isinstance(right, CallableType):
228247
return is_callable_compatible(left, right,
229248
is_compat=_is_overlapping_types,

mypy/messages.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,13 @@ def incompatible_typevar_value(self,
975975
.format(typevar_name, callable_name(callee) or 'function', self.format(typ)),
976976
context)
977977

978+
def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None:
979+
left_str = 'element' if kind == 'container' else 'left operand'
980+
right_str = 'container item' if kind == 'container' else 'right operand'
981+
message = 'Non-overlapping {} check ({} type: {}, {} type: {})'
982+
left_typ, right_typ = self.format_distinctly(left, right)
983+
self.fail(message.format(kind, left_str, left_typ, right_str, right_typ), ctx)
984+
978985
def overload_inconsistently_applies_decorator(self, decorator: str, context: Context) -> None:
979986
self.fail(
980987
'Overload does not consistently use the "@{}" '.format(decorator)

mypy/options.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class BuildType:
2222
# Please keep this list sorted
2323
"allow_untyped_globals",
2424
"allow_redefinition",
25+
"strict_equality",
2526
"always_false",
2627
"always_true",
2728
"check_untyped_defs",
@@ -157,6 +158,10 @@ def __init__(self) -> None:
157158
# and the same nesting level as the initialization
158159
self.allow_redefinition = False
159160

161+
# Prohibit equality, identity, and container checks for non-overlapping types.
162+
# This makes 1 == '1', 1 in ['1'], and 1 is '1' errors.
163+
self.strict_equality = False
164+
160165
# Variable names considered True
161166
self.always_true = [] # type: List[str]
162167

0 commit comments

Comments
 (0)