Skip to content

Add flag to prohibit equality checks between non-overlapping checks #6370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 15, 2019
18 changes: 18 additions & 0 deletions docs/source/command_line.rst
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,24 @@ of the above sections.
# 'items' now has type List[List[str]]
...

``--strict-equality``
By default, mypy allows always-false comparisons like ``42 == 'no'``.
Use this flag to prohibit such comparisons of non-overlapping types, and
similar identity and container checks:

.. code-block:: python

from typing import Text

text: Text
if b'some bytes' in text: # Error: non-overlapping check!
...
if text != b'other bytes': # Error: non-overlapping check!
...

if text is not None: # Error: non-overlapping check, 'text' can't be None.
...

``--strict``
This flag mode enables all optional error checking flags. You can see the
list of flags enabled by strict mode in the full ``mypy --help`` output.
Expand Down
4 changes: 4 additions & 0 deletions docs/source/config_file.rst
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ Miscellaneous strictness flags
Allows variables to be redefined with an arbitrary type, as long as the redefinition
is in the same block and nesting level as the original definition.

``strict_equality`` (bool, default False)
Prohibit equality checks, identity checks, and container checks between
non-overlapping types.

Global-only options
*******************

Expand Down
19 changes: 19 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2997,6 +2997,25 @@ def analyze_iterable_item_type(self, expr: Expression) -> Tuple[Type, Type]:
nextmethod = 'next'
return iterator, echk.check_method_call_by_name(nextmethod, iterator, [], [], expr)[0]

def analyze_container_item_type(self, typ: Type) -> Optional[Type]:
"""Check if a type is a nominal container of a union of such.

Return the corresponding container item type.
"""
if isinstance(typ, UnionType):
types = [] # type: List[Type]
for item in typ.items:
c_type = self.analyze_container_item_type(item)
if c_type:
types.append(c_type)
return UnionType.make_union(types)
if isinstance(typ, Instance) and typ.type.has_base('typing.Container'):
supertype = self.named_type('typing.Container').type
super_instance = map_instance_to_supertype(typ, supertype)
assert len(super_instance.args) == 1
return super_instance.args[0]
return None

def analyze_index_variables(self, index: Expression, item_type: Type,
infer_lvalue_type: bool, context: Context) -> None:
"""Type check or infer for loop or list comprehension index vars."""
Expand Down
44 changes: 41 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from mypy import message_registry
from mypy.infer import infer_type_arguments, infer_function_type_arguments
from mypy import join
from mypy.meet import narrow_declared_type
from mypy.meet import narrow_declared_type, is_overlapping_types
from mypy.subtypes import (
is_subtype, is_proper_subtype, is_equivalent, find_member, non_method_protocol_members,
)
Expand Down Expand Up @@ -1914,6 +1914,11 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
_, method_type = self.check_method_call_by_name(
'__contains__', right_type, [left], [ARG_POS], e, local_errors)
sub_result = self.bool_type()
# Container item type for strict type overlap checks. Note: we need to only
# check for nominal type, because a usual "Unsupported operands for in"
# will be reported for types incompatible with __contains__().
# See testCustomContainsCheckStrictEquality for an example.
cont_type = self.chk.analyze_container_item_type(right_type)
if isinstance(right_type, PartialType):
# We don't really know if this is an error or not, so just shut up.
pass
Expand All @@ -1929,16 +1934,29 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
self.named_type('builtins.function'))
if not is_subtype(left_type, itertype):
self.msg.unsupported_operand_types('in', left_type, right_type, e)
# Only show dangerous overlap if there are no other errors.
elif (not local_errors.is_errors() and cont_type and
self.dangerous_comparison(left_type, cont_type)):
self.msg.dangerous_comparison(left_type, cont_type, 'container', e)
else:
self.msg.add_errors(local_errors)
elif operator in nodes.op_methods:
method = self.get_operator_method(operator)
err_count = self.msg.errors.total_errors()
sub_result, method_type = self.check_op(method, left_type, right, e,
allow_reverse=True)
allow_reverse=True)
# Only show dangerous overlap if there are no other errors. See
# testCustomEqCheckStrictEquality for an example.
if self.msg.errors.total_errors() == err_count and operator in ('==', '!='):
right_type = self.accept(right)
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)

elif operator == 'is' or operator == 'is not':
self.accept(right) # validate the right operand
right_type = self.accept(right) # validate the right operand
sub_result = self.bool_type()
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, 'identity', e)
method_type = None
else:
raise RuntimeError('Unknown comparison operator {}'.format(operator))
Expand All @@ -1954,6 +1972,26 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
assert result is not None
return result

def dangerous_comparison(self, left: Type, right: Type) -> bool:
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.

Rules:
* X and None are non-overlapping in strict-optional mode, and
overlapping otherwise.
* Optional[X] and Optional[Y] are non-overlapping if X and Y are
non-overlapping, although technically None is overlap, it is most
likely an error.
* Any overlaps with everything, i.e. always safe.
* Promotions are ignored, so both 'abc' == b'abc' and 1 == 1.0
are errors.
"""
if isinstance(left, UnionType) and isinstance(right, UnionType):
left = remove_optional(left)
right = remove_optional(right)
if self.chk.options.strict_equality:
return not is_overlapping_types(left, right, ignore_promotions=True)
return False

def get_operator_method(self, op: str) -> str:
if op == '/' and self.chk.options.python_version[0] == 2:
# TODO also check for "from __future__ import division"
Expand Down
3 changes: 3 additions & 0 deletions mypy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def copy(self) -> 'Errors':
new.scope = self.scope
return new

def total_errors(self) -> int:
return sum(len(errs) for errs in self.error_info_map.values())

def set_ignore_prefix(self, prefix: str) -> None:
"""Set path prefix that will be removed from all paths."""
prefix = os.path.normpath(prefix)
Expand Down
5 changes: 5 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,11 @@ def add_invertible_flag(flag: str,
help="Allow unconditional variable redefinition with a new type",
group=strictness_group)

add_invertible_flag('--strict-equality', default=False, strict_flag=False,
help="Prohibit equality, identity, and container checks for"
" non-overlapping types",
group=strictness_group)

incremental_group = parser.add_argument_group(
title='Incremental mode',
description="Adjust how mypy incrementally type checks and caches modules. "
Expand Down
7 changes: 7 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,13 @@ def incompatible_typevar_value(self,
.format(typevar_name, callable_name(callee) or 'function', self.format(typ)),
context)

def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None:
left_str = 'element' if kind == 'container' else 'left operand'
right_str = 'container item' if kind == 'container' else 'right operand'
message = 'Non-overlapping {} check ({} type: {}, {} type: {})'
left_typ, right_typ = self.format_distinctly(left, right)
self.fail(message.format(kind, left_str, left_typ, right_str, right_typ), ctx)

def overload_inconsistently_applies_decorator(self, decorator: str, context: Context) -> None:
self.fail(
'Overload does not consistently use the "@{}" '.format(decorator)
Expand Down
5 changes: 5 additions & 0 deletions mypy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class BuildType:
# Please keep this list sorted
"allow_untyped_globals",
"allow_redefinition",
"strict_equality",
"always_false",
"always_true",
"check_untyped_defs",
Expand Down Expand Up @@ -157,6 +158,10 @@ def __init__(self) -> None:
# and the same nesting level as the initialization
self.allow_redefinition = False

# Prohibit equality, identity, and container checks for non-overlapping types.
# This makes 1 == '1', 1 in ['1'], and 1 is '1' errors.
self.strict_equality = False

# Variable names considered True
self.always_true = [] # type: List[str]

Expand Down
143 changes: 141 additions & 2 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ class C:
# type: ('int') -> bool
pass

[builtins_py2 fixtures/bool.pyi]
[builtins_py2 fixtures/bool_py2.pyi]

[case cmpIgnoredPy3]

Expand Down Expand Up @@ -604,7 +604,7 @@ class X:
class Y:
def __lt__(self, o: 'Y') -> A: pass
def __gt__(self, o: 'Y') -> A: pass
def __eq__(self, o: 'Y') -> B: pass
def __eq__(self, o: 'Y') -> B: pass # type: ignore
[builtins fixtures/bool.pyi]


Expand Down Expand Up @@ -1947,3 +1947,142 @@ a.__pow__() # E: Too few arguments for "__pow__" of "int"
x, y = [], [] # E: Need type annotation for 'x' \
# E: Need type annotation for 'y'
[builtins fixtures/list.pyi]

[case testStrictEqualityEq]
# flags: --strict-equality
class A: ...
class B: ...
class C(B): ...

A() == B() # E: Non-overlapping equality check (left operand type: "A", right operand type: "B")
B() == C()
C() == B()
A() != B() # E: Non-overlapping equality check (left operand type: "A", right operand type: "B")
B() != C()
C() != B()
[builtins fixtures/bool.pyi]

[case testStrictEqualityIs]
# flags: --strict-equality
class A: ...
class B: ...
class C(B): ...

A() is B() # E: Non-overlapping identity check (left operand type: "A", right operand type: "B")
B() is C()
C() is B()
A() is not B() # E: Non-overlapping identity check (left operand type: "A", right operand type: "B")
B() is not C()
C() is not B()
[builtins fixtures/bool.pyi]

[case testStrictEqualityContains]
# flags: --strict-equality
class A: ...
class B: ...
class C(B): ...

A() in [B()] # E: Non-overlapping container check (element type: "A", container item type: "B")
B() in [C()]
C() in [B()]
A() not in [B()] # E: Non-overlapping container check (element type: "A", container item type: "B")
B() not in [C()]
C() not in [B()]
[builtins fixtures/list.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityUnions]
# flags: --strict-equality
from typing import Container, Union

class A: ...
class B: ...

a: Union[int, str]
b: Union[A, B]

a == 42
b == 42 # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "int")

a is 42
b is 42 # E: Non-overlapping identity check (left operand type: "Union[A, B]", right operand type: "int")

ca: Union[Container[int], Container[str]]
cb: Union[Container[A], Container[B]]

42 in ca
42 in cb # E: Non-overlapping container check (element type: "int", container item type: "Union[A, B]")
[builtins fixtures/bool.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityNoPromote]
# flags: --strict-equality
'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes")
b'a' in 'abc' # E: Non-overlapping container check (element type: "bytes", container item type: "str")

x: str
y: bytes
x != y # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes")
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityAny]
# flags: --strict-equality
from typing import Any, Container

x: Any
c: Container[str]
x in c
x == 42
x is 42
[builtins fixtures/bool.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityStrictOptional]
# flags: --strict-equality --strict-optional

x: str
if x is not None: # E: Non-overlapping identity check (left operand type: "str", right operand type: "None")
pass
[builtins fixtures/bool.pyi]

[case testStrictEqualityNoStrictOptional]
# flags: --strict-equality --no-strict-optional

x: str
if x is not None: # OK without strict-optional
pass
[builtins fixtures/bool.pyi]

[case testStrictEqualityEqNoOptionalOverlap]
# flags: --strict-equality --strict-optional
from typing import Optional

x: Optional[str]
y: Optional[int]
if x == y: # E: Non-overlapping equality check (left operand type: "Optional[str]", right operand type: "Optional[int]")
...
[builtins fixtures/bool.pyi]

[case testCustomEqCheckStrictEquality]
# flags: --strict-equality
class A:
def __eq__(self, other: A) -> bool: # type: ignore
...
class B:
def __eq__(self, other: B) -> bool: # type: ignore
...

# Don't report non-overlapping check if there is already and error.
A() == B() # E: Unsupported operand types for == ("A" and "B")
[builtins fixtures/bool.pyi]

[case testCustomContainsCheckStrictEquality]
# flags: --strict-equality
class A:
def __contains__(self, other: A) -> bool:
...

# Don't report non-overlapping check if there is already and error.
42 in A() # E: Unsupported operand types for in ("int" and "A")
[builtins fixtures/bool.pyi]
13 changes: 13 additions & 0 deletions test-data/unit/check-flags.test
Original file line number Diff line number Diff line change
Expand Up @@ -1107,3 +1107,16 @@ class A(Generic[T]):
def f(c: A) -> None: # E: Missing type parameters for generic type
pass
[out]

[case testStrictEqualityPerFile]
# flags: --config-file tmp/mypy.ini
import b
42 == 'no' # E: Non-overlapping equality check (left operand type: "int", right operand type: "str")
[file b.py]
42 == 'no'
[file mypy.ini]
[[mypy]
strict_equality = True
[[mypy-b]
strict_equality = False
[builtins fixtures/bool.pyi]
3 changes: 2 additions & 1 deletion test-data/unit/fixtures/async_await.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ U = typing.TypeVar('U')
class list(typing.Sequence[T]):
def __iter__(self) -> typing.Iterator[T]: ...
def __getitem__(self, i: int) -> T: ...
def __contains__(self, item: object) -> bool: ...

class object:
def __init__(self) -> None: pass
class type: pass
class function: pass
class int: pass
class str: pass
class bool: pass
class bool(int): pass
class dict(typing.Generic[T, U]): pass
class set(typing.Generic[T]): pass
class tuple(typing.Generic[T]): pass
Expand Down
Loading