Skip to content

Add support for conditionally defined overloads #10712

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Mar 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
99436ee
Add support for conditionally defined overloads
cdce8p Jun 25, 2021
7e7502b
Bugfix
cdce8p Jun 25, 2021
52ba893
Merge remote-tracking branch 'upstream/master' into conditional-overl…
cdce8p Dec 14, 2021
3effa15
Redo logic to support elif + else
cdce8p Dec 14, 2021
ee6ad3c
Fix typing issues
cdce8p Dec 14, 2021
bc486e0
Fix small issue with merging IfStmt
cdce8p Dec 15, 2021
a1370e0
Update existing tests
cdce8p Dec 15, 2021
0d2dee3
Add additional tests
cdce8p Dec 15, 2021
ae394cc
Fix check-functions tests
cdce8p Dec 15, 2021
1dd5679
Fix crash
cdce8p Dec 15, 2021
a8c3899
Remove redundant cast
cdce8p Dec 15, 2021
3a93f7e
Add last test cases
cdce8p Dec 15, 2021
9a4c703
Fix tests
cdce8p Dec 15, 2021
2777ce1
Typecheck skipped IfStmt conditions
cdce8p Dec 15, 2021
b36a5f8
More tests
cdce8p Dec 15, 2021
4c2e98b
Merge remote-tracking branch 'upstream/master' into conditional-overl…
cdce8p Dec 15, 2021
1452020
Merge remote-tracking branch 'upstream/master' into conditional-overl…
cdce8p Dec 16, 2021
dcc49b2
Merge remote-tracking branch 'upstream/master' into conditional-overl…
cdce8p Jan 10, 2022
f8e27b2
Merge remote-tracking branch 'upstream/master' into conditional-overl…
cdce8p Jan 18, 2022
14defb4
Merge remote-tracking branch 'upstream/master' into conditional-overl…
cdce8p Jan 30, 2022
d6cc690
Merge remote-tracking branch 'upstream/master' into conditional-overl…
cdce8p Feb 6, 2022
7cb5eac
Merge remote-tracking branch 'upstream/master' into conditional-overl…
cdce8p Feb 25, 2022
80b05d5
Apply suggestions from review
cdce8p Mar 2, 2022
3d0397a
Don't merge starting If blocks without overloads
cdce8p Mar 2, 2022
914d517
Emit error if condition can't be inferred
cdce8p Mar 2, 2022
4f79d43
Add additional test cases
cdce8p Mar 2, 2022
4776070
Add documentation
cdce8p Mar 2, 2022
cd1e9b0
Copyedits to the docs
JelleZijlstra Mar 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions docs/source/more_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,114 @@ with ``Union[int, slice]`` and ``Union[T, Sequence]``.
to returning ``Any`` only if the input arguments also contain ``Any``.


Conditional overloads
---------------------

Sometimes it is useful to define overloads conditionally.
Common use cases include types that are unavailable at runtime or that
only exist in a certain Python version. All existing overload rules still apply.
For example, there must be at least two overloads.

.. note::

Mypy can only infer a limited number of conditions.
Supported ones currently include :py:data:`~typing.TYPE_CHECKING`, ``MYPY``,
:ref:`version_and_platform_checks`, and :option:`--always-true <mypy --always-true>`
and :option:`--always-false <mypy --always-false>` values.

.. code-block:: python

from typing import TYPE_CHECKING, Any, overload

if TYPE_CHECKING:
class A: ...
class B: ...


if TYPE_CHECKING:
@overload
def func(var: A) -> A: ...

@overload
def func(var: B) -> B: ...

def func(var: Any) -> Any:
return var


reveal_type(func(A())) # Revealed type is "A"

.. code-block:: python

# flags: --python-version 3.10
import sys
from typing import Any, overload

class A: ...
class B: ...
class C: ...
class D: ...


if sys.version_info < (3, 7):
@overload
def func(var: A) -> A: ...

elif sys.version_info >= (3, 10):
@overload
def func(var: B) -> B: ...

else:
@overload
def func(var: C) -> C: ...

@overload
def func(var: D) -> D: ...

def func(var: Any) -> Any:
return var


reveal_type(func(B())) # Revealed type is "B"
reveal_type(func(C())) # No overload variant of "func" matches argument type "C"
# Possible overload variants:
# def func(var: B) -> B
# def func(var: D) -> D
# Revealed type is "Any"


.. note::

In the last example, mypy is executed with
:option:`--python-version 3.10 <mypy --python-version>`.
Therefore, the condition ``sys.version_info >= (3, 10)`` will match and
the overload for ``B`` will be added.
The overloads for ``A`` and ``C`` are ignored!
The overload for ``D`` is not defined conditionally and thus is also added.

When mypy cannot infer a condition to be always True or always False, an error is emitted.

.. code-block:: python

from typing import Any, overload

class A: ...
class B: ...


def g(bool_var: bool) -> None:
if bool_var: # Condition can't be inferred, unable to merge overloads
@overload
def func(var: A) -> A: ...

@overload
def func(var: B) -> B: ...

def func(var: Any) -> Any: ...

reveal_type(func(A())) # Revealed type is "Any"


.. _advanced_self:

Advanced uses of self-types
Expand Down
199 changes: 196 additions & 3 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from mypy import message_registry, errorcodes as codes
from mypy.errors import Errors
from mypy.options import Options
from mypy.reachability import mark_block_unreachable
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable
from mypy.util import bytes_to_human_readable_repr

try:
Expand Down Expand Up @@ -344,9 +344,19 @@ def fail(self,
msg: str,
line: int,
column: int,
blocker: bool = True) -> None:
blocker: bool = True,
code: codes.ErrorCode = codes.SYNTAX) -> None:
if blocker or not self.options.ignore_errors:
self.errors.report(line, column, msg, blocker=blocker, code=codes.SYNTAX)
self.errors.report(line, column, msg, blocker=blocker, code=code)

def fail_merge_overload(self, node: IfStmt) -> None:
self.fail(
"Condition can't be inferred, unable to merge overloads",
line=node.line,
column=node.column,
blocker=False,
code=codes.MISC,
)

def visit(self, node: Optional[AST]) -> Any:
if node is None:
Expand Down Expand Up @@ -476,12 +486,93 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
ret: List[Statement] = []
current_overload: List[OverloadPart] = []
current_overload_name: Optional[str] = None
last_if_stmt: Optional[IfStmt] = None
last_if_overload: Optional[Union[Decorator, FuncDef, OverloadedFuncDef]] = None
last_if_stmt_overload_name: Optional[str] = None
last_if_unknown_truth_value: Optional[IfStmt] = None
skipped_if_stmts: List[IfStmt] = []
for stmt in stmts:
if_overload_name: Optional[str] = None
if_block_with_overload: Optional[Block] = None
if_unknown_truth_value: Optional[IfStmt] = None
if (
isinstance(stmt, IfStmt)
and len(stmt.body[0].body) == 1
and (
isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
or current_overload_name is not None
and isinstance(stmt.body[0].body[0], FuncDef)
)
):
# Check IfStmt block to determine if function overloads can be merged
if_overload_name = self._check_ifstmt_for_overloads(stmt)
if if_overload_name is not None:
if_block_with_overload, if_unknown_truth_value = \
self._get_executable_if_block_with_overloads(stmt)

if (current_overload_name is not None
and isinstance(stmt, (Decorator, FuncDef))
and stmt.name == current_overload_name):
if last_if_stmt is not None:
skipped_if_stmts.append(last_if_stmt)
if last_if_overload is not None:
# Last stmt was an IfStmt with same overload name
# Add overloads to current_overload
if isinstance(last_if_overload, OverloadedFuncDef):
current_overload.extend(last_if_overload.items)
else:
current_overload.append(last_if_overload)
last_if_stmt, last_if_overload = None, None
if last_if_unknown_truth_value:
self.fail_merge_overload(last_if_unknown_truth_value)
last_if_unknown_truth_value = None
current_overload.append(stmt)
elif (
current_overload_name is not None
and isinstance(stmt, IfStmt)
and if_overload_name == current_overload_name
):
# IfStmt only contains stmts relevant to current_overload.
# Check if stmts are reachable and add them to current_overload,
# otherwise skip IfStmt to allow subsequent overload
# or function definitions.
skipped_if_stmts.append(stmt)
if if_block_with_overload is None:
if if_unknown_truth_value is not None:
self.fail_merge_overload(if_unknown_truth_value)
continue
if last_if_overload is not None:
# Last stmt was an IfStmt with same overload name
# Add overloads to current_overload
if isinstance(last_if_overload, OverloadedFuncDef):
current_overload.extend(last_if_overload.items)
else:
current_overload.append(last_if_overload)
last_if_stmt, last_if_overload = None, None
if isinstance(if_block_with_overload.body[0], OverloadedFuncDef):
current_overload.extend(if_block_with_overload.body[0].items)
else:
current_overload.append(
cast(Union[Decorator, FuncDef], if_block_with_overload.body[0])
)
else:
if last_if_stmt is not None:
ret.append(last_if_stmt)
last_if_stmt_overload_name = current_overload_name
last_if_stmt, last_if_overload = None, None
last_if_unknown_truth_value = None

if current_overload and current_overload_name == last_if_stmt_overload_name:
# Remove last stmt (IfStmt) from ret if the overload names matched
# Only happens if no executable block had been found in IfStmt
skipped_if_stmts.append(cast(IfStmt, ret.pop()))
if current_overload and skipped_if_stmts:
# Add bare IfStmt (without overloads) to ret
# Required for mypy to be able to still check conditions
for if_stmt in skipped_if_stmts:
self._strip_contents_from_if_stmt(if_stmt)
ret.append(if_stmt)
skipped_if_stmts = []
if len(current_overload) == 1:
ret.append(current_overload[0])
elif len(current_overload) > 1:
Expand All @@ -495,17 +586,119 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
current_overload = [stmt]
current_overload_name = stmt.name
elif (
isinstance(stmt, IfStmt)
and if_overload_name is not None
):
current_overload = []
current_overload_name = if_overload_name
last_if_stmt = stmt
last_if_stmt_overload_name = None
if if_block_with_overload is not None:
last_if_overload = cast(
Union[Decorator, FuncDef, OverloadedFuncDef],
if_block_with_overload.body[0]
)
last_if_unknown_truth_value = if_unknown_truth_value
else:
current_overload = []
current_overload_name = None
ret.append(stmt)

if current_overload and skipped_if_stmts:
# Add bare IfStmt (without overloads) to ret
# Required for mypy to be able to still check conditions
for if_stmt in skipped_if_stmts:
self._strip_contents_from_if_stmt(if_stmt)
ret.append(if_stmt)
if len(current_overload) == 1:
ret.append(current_overload[0])
elif len(current_overload) > 1:
ret.append(OverloadedFuncDef(current_overload))
elif last_if_stmt is not None:
ret.append(last_if_stmt)
return ret

def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]:
"""Check if IfStmt contains only overloads with the same name.
Return overload_name if found, None otherwise.
"""
# Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef.
# Multiple overloads have already been merged as OverloadedFuncDef.
if not (
len(stmt.body[0].body) == 1
and isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
):
return None

overload_name = stmt.body[0].body[0].name
if stmt.else_body is None:
return overload_name

if isinstance(stmt.else_body, Block) and len(stmt.else_body.body) == 1:
# For elif: else_body contains an IfStmt itself -> do a recursive check.
if (
isinstance(stmt.else_body.body[0], (Decorator, FuncDef, OverloadedFuncDef))
and stmt.else_body.body[0].name == overload_name
):
return overload_name
if (
isinstance(stmt.else_body.body[0], IfStmt)
and self._check_ifstmt_for_overloads(stmt.else_body.body[0]) == overload_name
):
return overload_name

return None

def _get_executable_if_block_with_overloads(
self, stmt: IfStmt
) -> Tuple[Optional[Block], Optional[IfStmt]]:
"""Return block from IfStmt that will get executed.

Return
0 -> A block if sure that alternative blocks are unreachable.
1 -> An IfStmt if the reachability of it can't be inferred,
i.e. the truth value is unknown.
"""
infer_reachability_of_if_statement(stmt, self.options)
if (
stmt.else_body is None
and stmt.body[0].is_unreachable is True
):
# always False condition with no else
return None, None
if (
stmt.else_body is None
or stmt.body[0].is_unreachable is False
and stmt.else_body.is_unreachable is False
):
# The truth value is unknown, thus not conclusive
return None, stmt
if stmt.else_body.is_unreachable is True:
# else_body will be set unreachable if condition is always True
return stmt.body[0], None
if stmt.body[0].is_unreachable is True:
# body will be set unreachable if condition is always False
# else_body can contain an IfStmt itself (for elif) -> do a recursive check
if isinstance(stmt.else_body.body[0], IfStmt):
return self._get_executable_if_block_with_overloads(stmt.else_body.body[0])
return stmt.else_body, None
return None, stmt

def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None:
"""Remove contents from IfStmt.

Needed to still be able to check the conditions after the contents
have been merged with the surrounding function overloads.
"""
if len(stmt.body) == 1:
stmt.body[0].body = []
if stmt.else_body and len(stmt.else_body.body) == 1:
if isinstance(stmt.else_body.body[0], IfStmt):
self._strip_contents_from_if_stmt(stmt.else_body.body[0])
else:
stmt.else_body.body = []

def in_method_scope(self) -> bool:
return self.class_and_function_stack[-2:] == ['C', 'F']

Expand Down
Loading