Skip to content

Add flag to prohibit equality checks between non-overlapping checks #6370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 15, 2019
17 changes: 17 additions & 0 deletions docs/source/command_line.rst
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,23 @@ of the above sections.
# 'items' now has type List[List[str]]
...

``--strict-equality``
By default, mypy allows always-false comparisons like ``42 == 'no'``.
Use this flag to prohibit such comparisons of non-overlapping types, and
similar identity and container checks:

.. code-block:: python

from typing import Text

text: Text
if b'some bytes' in text: # Error: non-overlapping check!
...
if text != b'other bytes': # Error: non-overlapping check!
...

assert text is not None # OK, this special case is allowed.

``--strict``
This flag mode enables all optional error checking flags. You can see the
list of flags enabled by strict mode in the full ``mypy --help`` output.
Expand Down
4 changes: 4 additions & 0 deletions docs/source/config_file.rst
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ Miscellaneous strictness flags
Allows variables to be redefined with an arbitrary type, as long as the redefinition
is in the same block and nesting level as the original definition.

``strict_equality`` (bool, default False)
Prohibit equality checks, identity checks, and container checks between
non-overlapping types.

Global-only options
*******************

Expand Down
19 changes: 19 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2997,6 +2997,25 @@ def analyze_iterable_item_type(self, expr: Expression) -> Tuple[Type, Type]:
nextmethod = 'next'
return iterator, echk.check_method_call_by_name(nextmethod, iterator, [], [], expr)[0]

def analyze_container_item_type(self, typ: Type) -> Optional[Type]:
"""Check if a type is a nominal container of a union of such.

Return the corresponding container item type.
"""
if isinstance(typ, UnionType):
types = [] # type: List[Type]
for item in typ.items:
c_type = self.analyze_container_item_type(item)
if c_type:
types.append(c_type)
return UnionType.make_union(types)
if isinstance(typ, Instance) and typ.type.has_base('typing.Container'):
supertype = self.named_type('typing.Container').type
super_instance = map_instance_to_supertype(typ, supertype)
assert len(super_instance.args) == 1
return super_instance.args[0]
return None

def analyze_index_variables(self, index: Expression, item_type: Type,
infer_lvalue_type: bool, context: Context) -> None:
"""Type check or infer for loop or list comprehension index vars."""
Expand Down
48 changes: 45 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from mypy import message_registry
from mypy.infer import infer_type_arguments, infer_function_type_arguments
from mypy import join
from mypy.meet import narrow_declared_type
from mypy.meet import narrow_declared_type, is_overlapping_types
from mypy.subtypes import (
is_subtype, is_proper_subtype, is_equivalent, find_member, non_method_protocol_members,
)
Expand Down Expand Up @@ -1914,6 +1914,11 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
_, method_type = self.check_method_call_by_name(
'__contains__', right_type, [left], [ARG_POS], e, local_errors)
sub_result = self.bool_type()
# Container item type for strict type overlap checks. Note: we need to only
# check for nominal type, because a usual "Unsupported operands for in"
# will be reported for types incompatible with __contains__().
# See testCustomContainsCheckStrictEquality for an example.
cont_type = self.chk.analyze_container_item_type(right_type)
if isinstance(right_type, PartialType):
# We don't really know if this is an error or not, so just shut up.
pass
Expand All @@ -1929,16 +1934,29 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
self.named_type('builtins.function'))
if not is_subtype(left_type, itertype):
self.msg.unsupported_operand_types('in', left_type, right_type, e)
# Only show dangerous overlap if there are no other errors.
elif (not local_errors.is_errors() and cont_type and
self.dangerous_comparison(left_type, cont_type)):
self.msg.dangerous_comparison(left_type, cont_type, 'container', e)
else:
self.msg.add_errors(local_errors)
elif operator in nodes.op_methods:
method = self.get_operator_method(operator)
err_count = self.msg.errors.total_errors()
sub_result, method_type = self.check_op(method, left_type, right, e,
allow_reverse=True)
allow_reverse=True)
# Only show dangerous overlap if there are no other errors. See
# testCustomEqCheckStrictEquality for an example.
if self.msg.errors.total_errors() == err_count and operator in ('==', '!='):
right_type = self.accept(right)
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)

elif operator == 'is' or operator == 'is not':
self.accept(right) # validate the right operand
right_type = self.accept(right) # validate the right operand
sub_result = self.bool_type()
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, 'identity', e)
method_type = None
else:
raise RuntimeError('Unknown comparison operator {}'.format(operator))
Expand All @@ -1954,6 +1972,30 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
assert result is not None
return result

def dangerous_comparison(self, left: Type, right: Type) -> bool:
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.

Rules:
* X and None are overlapping even in strict-optional mode. This is to allow
'assert x is not None' for x defined as 'x = None # type: str' in class body
(otherwise mypy itself would have couple dozen errors because of this).
* Optional[X] and Optional[Y] are non-overlapping if X and Y are
non-overlapping, although technically None is overlap, it is most
likely an error.
* Any overlaps with everything, i.e. always safe.
* Promotions are ignored, so both 'abc' == b'abc' and 1 == 1.0
are errors. This is mostly needed for bytes vs unicode, and
int vs float are added just for consistency.
"""
if not self.chk.options.strict_equality:
return False
if isinstance(left, NoneTyp) or isinstance(right, NoneTyp):
return False
if isinstance(left, UnionType) and isinstance(right, UnionType):
left = remove_optional(left)
right = remove_optional(right)
return not is_overlapping_types(left, right, ignore_promotions=True)

def get_operator_method(self, op: str) -> str:
if op == '/' and self.chk.options.python_version[0] == 2:
# TODO also check for "from __future__ import division"
Expand Down
3 changes: 3 additions & 0 deletions mypy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def copy(self) -> 'Errors':
new.scope = self.scope
return new

def total_errors(self) -> int:
return sum(len(errs) for errs in self.error_info_map.values())

def set_ignore_prefix(self, prefix: str) -> None:
"""Set path prefix that will be removed from all paths."""
prefix = os.path.normpath(prefix)
Expand Down
5 changes: 5 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,11 @@ def add_invertible_flag(flag: str,
help="Allow unconditional variable redefinition with a new type",
group=strictness_group)

add_invertible_flag('--strict-equality', default=False, strict_flag=False,
help="Prohibit equality, identity, and container checks for"
" non-overlapping types",
group=strictness_group)

incremental_group = parser.add_argument_group(
title='Incremental mode',
description="Adjust how mypy incrementally type checks and caches modules. "
Expand Down
21 changes: 20 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,28 @@ def is_none_typevar_overlap(t1: Type, t2: Type) -> bool:
# As before, we degrade into 'Instance' whenever possible.

if isinstance(left, TypeType) and isinstance(right, TypeType):
# TODO: Can Callable[[...], T] and Type[T] be partially overlapping?
return _is_overlapping_types(left.item, right.item)

def _type_object_overlap(left: Type, right: Type) -> bool:
"""Special cases for type object types overlaps."""
# TODO: these checks are a bit in gray area, adjust if they cause problems.
# 1. Type[C] vs Callable[..., C], where the latter is class object.
if isinstance(left, TypeType) and isinstance(right, CallableType) and right.is_type_obj():
return _is_overlapping_types(left.item, right.ret_type)
# 2. Type[C] vs Meta, where Meta is a metaclass for C.
if (isinstance(left, TypeType) and isinstance(left.item, Instance) and
isinstance(right, Instance)):
left_meta = left.item.type.metaclass_type
if left_meta is not None:
return _is_overlapping_types(left_meta, right)
# builtins.type (default metaclass) overlaps with all metaclasses
return right.type.has_base('builtins.type')
# 3. Callable[..., C] vs Meta is considered below, when we switch to fallbacks.
return False

if isinstance(left, TypeType) or isinstance(right, TypeType):
return _type_object_overlap(left, right) or _type_object_overlap(right, left)

if isinstance(left, CallableType) and isinstance(right, CallableType):
return is_callable_compatible(left, right,
is_compat=_is_overlapping_types,
Expand Down
7 changes: 7 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,13 @@ def incompatible_typevar_value(self,
.format(typevar_name, callable_name(callee) or 'function', self.format(typ)),
context)

def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None:
left_str = 'element' if kind == 'container' else 'left operand'
right_str = 'container item' if kind == 'container' else 'right operand'
message = 'Non-overlapping {} check ({} type: {}, {} type: {})'
left_typ, right_typ = self.format_distinctly(left, right)
self.fail(message.format(kind, left_str, left_typ, right_str, right_typ), ctx)

def overload_inconsistently_applies_decorator(self, decorator: str, context: Context) -> None:
self.fail(
'Overload does not consistently use the "@{}" '.format(decorator)
Expand Down
5 changes: 5 additions & 0 deletions mypy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class BuildType:
# Please keep this list sorted
"allow_untyped_globals",
"allow_redefinition",
"strict_equality",
"always_false",
"always_true",
"check_untyped_defs",
Expand Down Expand Up @@ -157,6 +158,10 @@ def __init__(self) -> None:
# and the same nesting level as the initialization
self.allow_redefinition = False

# Prohibit equality, identity, and container checks for non-overlapping types.
# This makes 1 == '1', 1 in ['1'], and 1 is '1' errors.
self.strict_equality = False

# Variable names considered True
self.always_true = [] # type: List[str]

Expand Down
Loading