Skip to content

Commit 3df1b9e

Browse files
committed
Add support for conditionally defined overloads
1 parent cdb2685 commit 3df1b9e

File tree

2 files changed

+209
-1
lines changed

2 files changed

+209
-1
lines changed

mypy/fastparse.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from mypy import message_registry, errorcodes as codes
3838
from mypy.errors import Errors
3939
from mypy.options import Options
40-
from mypy.reachability import mark_block_unreachable
40+
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable
4141

4242
try:
4343
# pull this into a final variable to make mypyc be quiet about the
@@ -444,12 +444,50 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
444444
ret: List[Statement] = []
445445
current_overload: List[OverloadPart] = []
446446
current_overload_name: Optional[str] = None
447+
last_if_stmt: Optional[IfStmt] = None
448+
last_if_overload: Optional[Union[Decorator, OverloadedFuncDef]] = None
447449
for stmt in stmts:
448450
if (current_overload_name is not None
449451
and isinstance(stmt, (Decorator, FuncDef))
450452
and stmt.name == current_overload_name):
453+
if last_if_overload is not None:
454+
if isinstance(last_if_overload, OverloadedFuncDef):
455+
current_overload.extend(last_if_overload.items)
456+
else:
457+
current_overload.append(last_if_overload)
458+
last_if_stmt, last_if_overload = None, None
451459
current_overload.append(stmt)
460+
elif (
461+
current_overload_name is not None
462+
and isinstance(stmt, IfStmt)
463+
and len(stmt.body[0].body) == 1
464+
and isinstance(
465+
stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
466+
and stmt.body[0].body[0].name == current_overload_name
467+
):
468+
# IfStmt only contains stmts relevant to current_overload.
469+
# Check if stmts are reachable and add them to current_overload,
470+
# otherwise skip IfStmt to allow subsequent overload
471+
# or function definitions.
472+
infer_reachability_of_if_statement(stmt, self.options)
473+
if stmt.body[0].is_unreachable is True:
474+
continue
475+
if last_if_overload is not None:
476+
if isinstance(last_if_overload, OverloadedFuncDef):
477+
current_overload.extend(last_if_overload.items)
478+
else:
479+
current_overload.append(last_if_overload)
480+
last_if_stmt, last_if_overload = None, None
481+
last_if_overload = None
482+
if isinstance(stmt.body[0].body[0], OverloadedFuncDef):
483+
current_overload.extend(stmt.body[0].body[0].items)
484+
else:
485+
current_overload.append(stmt.body[0].body[0])
452486
else:
487+
if last_if_stmt is not None:
488+
ret.append(last_if_stmt)
489+
last_if_stmt, last_if_overload = None, None
490+
453491
if len(current_overload) == 1:
454492
ret.append(current_overload[0])
455493
elif len(current_overload) > 1:
@@ -458,6 +496,19 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
458496
if isinstance(stmt, Decorator):
459497
current_overload = [stmt]
460498
current_overload_name = stmt.name
499+
elif (
500+
isinstance(stmt, IfStmt)
501+
and len(stmt.body[0].body) == 1
502+
and isinstance(
503+
stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
504+
and infer_reachability_of_if_statement(
505+
stmt, self.options
506+
) is None # type: ignore[func-returns-value]
507+
and stmt.body[0].is_unreachable is False
508+
):
509+
current_overload_name = stmt.body[0].body[0].name
510+
last_if_stmt = stmt
511+
last_if_overload = stmt.body[0].body[0]
461512
else:
462513
current_overload = []
463514
current_overload_name = None

test-data/unit/check-overloading.test

+157
Original file line numberDiff line numberDiff line change
@@ -5189,3 +5189,160 @@ def register(cls: Any) -> Any: return None
51895189
x = register(Foo)
51905190
reveal_type(x) # N: Revealed type is "builtins.int"
51915191
[builtins fixtures/dict.pyi]
5192+
5193+
[case testOverloadIfBasic]
5194+
# flags: --always-true True
5195+
from typing import overload, Any
5196+
5197+
class A: ...
5198+
class B: ...
5199+
5200+
@overload
5201+
def f1(g: int) -> A: ...
5202+
if True:
5203+
@overload
5204+
def f1(g: str) -> B: ...
5205+
def f1(g: Any) -> Any: ...
5206+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5207+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5208+
5209+
@overload
5210+
def f2(g: int) -> A: ...
5211+
@overload
5212+
def f2(g: bytes) -> A: ...
5213+
if not True:
5214+
@overload
5215+
def f2(g: str) -> B: ...
5216+
def f2(g: Any) -> Any: ...
5217+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5218+
reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \
5219+
# N: Possible overload variants: \
5220+
# N: def f2(g: int) -> A \
5221+
# N: def f2(g: bytes) -> A \
5222+
# N: Revealed type is "Any"
5223+
5224+
[case testOverloadIfSysVersion]
5225+
# flags: --python-version 3.9
5226+
from typing import overload, Any
5227+
import sys
5228+
5229+
class A: ...
5230+
class B: ...
5231+
5232+
@overload
5233+
def f1(g: int) -> A: ...
5234+
if sys.version_info >= (3, 9):
5235+
@overload
5236+
def f1(g: str) -> B: ...
5237+
def f1(g: Any) -> Any: ...
5238+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5239+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5240+
5241+
@overload
5242+
def f2(g: int) -> A: ...
5243+
@overload
5244+
def f2(g: bytes) -> A: ...
5245+
if sys.version_info >= (3, 10):
5246+
@overload
5247+
def f2(g: str) -> B: ...
5248+
def f2(g: Any) -> Any: ...
5249+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5250+
reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \
5251+
# N: Possible overload variants: \
5252+
# N: def f2(g: int) -> A \
5253+
# N: def f2(g: bytes) -> A \
5254+
# N: Revealed type is "Any"
5255+
[builtins fixtures/tuple.pyi]
5256+
5257+
[case testOverloadIfMatching]
5258+
from typing import overload, Any
5259+
5260+
class A: ...
5261+
class B: ...
5262+
class C: ...
5263+
5264+
@overload
5265+
def f1(g: int) -> A: ...
5266+
if True:
5267+
# Some comment
5268+
@overload
5269+
def f1(g: str) -> B: ...
5270+
def f1(g: Any) -> Any: ...
5271+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5272+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5273+
5274+
@overload
5275+
def f2(g: int) -> A: ...
5276+
if True:
5277+
@overload
5278+
def f2(g: bytes) -> B: ...
5279+
@overload
5280+
def f2(g: str) -> C: ...
5281+
def f2(g: Any) -> Any: ...
5282+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5283+
reveal_type(f2("Hello")) # N: Revealed type is "__main__.C"
5284+
5285+
@overload
5286+
def f3(g: int) -> A: ...
5287+
@overload
5288+
def f3(g: str) -> B: ...
5289+
if True:
5290+
def f3(g: Any) -> Any: ...
5291+
reveal_type(f3(42)) # N: Revealed type is "__main__.A"
5292+
reveal_type(f3("Hello")) # N: Revealed type is "__main__.B"
5293+
5294+
if True:
5295+
@overload
5296+
def f4(g: int) -> A: ...
5297+
@overload
5298+
def f4(g: str) -> B: ...
5299+
def f4(g: Any) -> Any: ...
5300+
reveal_type(f4(42)) # N: Revealed type is "__main__.A"
5301+
reveal_type(f4("Hello")) # N: Revealed type is "__main__.B"
5302+
5303+
if True:
5304+
# Some comment
5305+
@overload
5306+
def f5(g: int) -> A: ...
5307+
@overload
5308+
def f5(g: str) -> B: ...
5309+
def f5(g: Any) -> Any: ...
5310+
reveal_type(f5(42)) # N: Revealed type is "__main__.A"
5311+
reveal_type(f5("Hello")) # N: Revealed type is "__main__.B"
5312+
5313+
[case testOverloadIfNotMatching]
5314+
from typing import overload, Any
5315+
5316+
class A: ...
5317+
class B: ...
5318+
class C: ...
5319+
5320+
@overload # E: An overloaded function outside a stub file must have an implementation
5321+
def f1(g: int) -> A: ...
5322+
@overload
5323+
def f1(g: bytes) -> B: ...
5324+
if True:
5325+
@overload # E: Name "f1" already defined on line 7 \
5326+
# E: Single overload definition, multiple required
5327+
def f1(g: str) -> C: ...
5328+
pass # Some other action
5329+
def f1(g: Any) -> Any: ... # E: Name "f1" already defined on line 7
5330+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5331+
reveal_type(f1("Hello")) # E: No overload variant of "f1" matches argument type "str" \
5332+
# N: Possible overload variants: \
5333+
# N: def f1(g: int) -> A \
5334+
# N: def f1(g: bytes) -> B \
5335+
# N: Revealed type is "Any"
5336+
5337+
if True:
5338+
pass # Some other action
5339+
@overload # E: Single overload definition, multiple required
5340+
def f2(g: int) -> A: ...
5341+
@overload # E: Name "f2" already defined on line 21
5342+
def f2(g: bytes) -> B: ...
5343+
@overload
5344+
def f2(g: str) -> C: ...
5345+
def f2(g: Any) -> Any: ...
5346+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5347+
reveal_type(f2("Hello")) # N: Revealed type is "__main__.A" \
5348+
# E: Argument 1 to "f2" has incompatible type "str"; expected "int"

0 commit comments

Comments
 (0)