Skip to content

Commit 99436ee

Browse files
committed
Add support for conditionally defined overloads
1 parent e32f35b commit 99436ee

File tree

2 files changed

+209
-1
lines changed

2 files changed

+209
-1
lines changed

mypy/fastparse.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from mypy import message_registry, errorcodes as codes
4040
from mypy.errors import Errors
4141
from mypy.options import Options
42-
from mypy.reachability import mark_block_unreachable
42+
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable
4343

4444
try:
4545
# pull this into a final variable to make mypyc be quiet about the
@@ -447,12 +447,50 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
447447
ret: List[Statement] = []
448448
current_overload: List[OverloadPart] = []
449449
current_overload_name: Optional[str] = None
450+
last_if_stmt: Optional[IfStmt] = None
451+
last_if_overload: Optional[Union[Decorator, OverloadedFuncDef]] = None
450452
for stmt in stmts:
451453
if (current_overload_name is not None
452454
and isinstance(stmt, (Decorator, FuncDef))
453455
and stmt.name == current_overload_name):
456+
if last_if_overload is not None:
457+
if isinstance(last_if_overload, OverloadedFuncDef):
458+
current_overload.extend(last_if_overload.items)
459+
else:
460+
current_overload.append(last_if_overload)
461+
last_if_stmt, last_if_overload = None, None
454462
current_overload.append(stmt)
463+
elif (
464+
current_overload_name is not None
465+
and isinstance(stmt, IfStmt)
466+
and len(stmt.body[0].body) == 1
467+
and isinstance(
468+
stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
469+
and stmt.body[0].body[0].name == current_overload_name
470+
):
471+
# IfStmt only contains stmts relevant to current_overload.
472+
# Check if stmts are reachable and add them to current_overload,
473+
# otherwise skip IfStmt to allow subsequent overload
474+
# or function definitions.
475+
infer_reachability_of_if_statement(stmt, self.options)
476+
if stmt.body[0].is_unreachable is True:
477+
continue
478+
if last_if_overload is not None:
479+
if isinstance(last_if_overload, OverloadedFuncDef):
480+
current_overload.extend(last_if_overload.items)
481+
else:
482+
current_overload.append(last_if_overload)
483+
last_if_stmt, last_if_overload = None, None
484+
last_if_overload = None
485+
if isinstance(stmt.body[0].body[0], OverloadedFuncDef):
486+
current_overload.extend(stmt.body[0].body[0].items)
487+
else:
488+
current_overload.append(stmt.body[0].body[0])
455489
else:
490+
if last_if_stmt is not None:
491+
ret.append(last_if_stmt)
492+
last_if_stmt, last_if_overload = None, None
493+
456494
if len(current_overload) == 1:
457495
ret.append(current_overload[0])
458496
elif len(current_overload) > 1:
@@ -466,6 +504,19 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
466504
if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
467505
current_overload = [stmt]
468506
current_overload_name = stmt.name
507+
elif (
508+
isinstance(stmt, IfStmt)
509+
and len(stmt.body[0].body) == 1
510+
and isinstance(
511+
stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
512+
and infer_reachability_of_if_statement(
513+
stmt, self.options
514+
) is None # type: ignore[func-returns-value]
515+
and stmt.body[0].is_unreachable is False
516+
):
517+
current_overload_name = stmt.body[0].body[0].name
518+
last_if_stmt = stmt
519+
last_if_overload = stmt.body[0].body[0]
469520
else:
470521
current_overload = []
471522
current_overload_name = None

test-data/unit/check-overloading.test

+157
Original file line numberDiff line numberDiff line change
@@ -5339,3 +5339,160 @@ def register(cls: Any) -> Any: return None
53395339
x = register(Foo)
53405340
reveal_type(x) # N: Revealed type is "builtins.int"
53415341
[builtins fixtures/dict.pyi]
5342+
5343+
[case testOverloadIfBasic]
5344+
# flags: --always-true True
5345+
from typing import overload, Any
5346+
5347+
class A: ...
5348+
class B: ...
5349+
5350+
@overload
5351+
def f1(g: int) -> A: ...
5352+
if True:
5353+
@overload
5354+
def f1(g: str) -> B: ...
5355+
def f1(g: Any) -> Any: ...
5356+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5357+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5358+
5359+
@overload
5360+
def f2(g: int) -> A: ...
5361+
@overload
5362+
def f2(g: bytes) -> A: ...
5363+
if not True:
5364+
@overload
5365+
def f2(g: str) -> B: ...
5366+
def f2(g: Any) -> Any: ...
5367+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5368+
reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \
5369+
# N: Possible overload variants: \
5370+
# N: def f2(g: int) -> A \
5371+
# N: def f2(g: bytes) -> A \
5372+
# N: Revealed type is "Any"
5373+
5374+
[case testOverloadIfSysVersion]
5375+
# flags: --python-version 3.9
5376+
from typing import overload, Any
5377+
import sys
5378+
5379+
class A: ...
5380+
class B: ...
5381+
5382+
@overload
5383+
def f1(g: int) -> A: ...
5384+
if sys.version_info >= (3, 9):
5385+
@overload
5386+
def f1(g: str) -> B: ...
5387+
def f1(g: Any) -> Any: ...
5388+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5389+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5390+
5391+
@overload
5392+
def f2(g: int) -> A: ...
5393+
@overload
5394+
def f2(g: bytes) -> A: ...
5395+
if sys.version_info >= (3, 10):
5396+
@overload
5397+
def f2(g: str) -> B: ...
5398+
def f2(g: Any) -> Any: ...
5399+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5400+
reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \
5401+
# N: Possible overload variants: \
5402+
# N: def f2(g: int) -> A \
5403+
# N: def f2(g: bytes) -> A \
5404+
# N: Revealed type is "Any"
5405+
[builtins fixtures/tuple.pyi]
5406+
5407+
[case testOverloadIfMatching]
5408+
from typing import overload, Any
5409+
5410+
class A: ...
5411+
class B: ...
5412+
class C: ...
5413+
5414+
@overload
5415+
def f1(g: int) -> A: ...
5416+
if True:
5417+
# Some comment
5418+
@overload
5419+
def f1(g: str) -> B: ...
5420+
def f1(g: Any) -> Any: ...
5421+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5422+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5423+
5424+
@overload
5425+
def f2(g: int) -> A: ...
5426+
if True:
5427+
@overload
5428+
def f2(g: bytes) -> B: ...
5429+
@overload
5430+
def f2(g: str) -> C: ...
5431+
def f2(g: Any) -> Any: ...
5432+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5433+
reveal_type(f2("Hello")) # N: Revealed type is "__main__.C"
5434+
5435+
@overload
5436+
def f3(g: int) -> A: ...
5437+
@overload
5438+
def f3(g: str) -> B: ...
5439+
if True:
5440+
def f3(g: Any) -> Any: ...
5441+
reveal_type(f3(42)) # N: Revealed type is "__main__.A"
5442+
reveal_type(f3("Hello")) # N: Revealed type is "__main__.B"
5443+
5444+
if True:
5445+
@overload
5446+
def f4(g: int) -> A: ...
5447+
@overload
5448+
def f4(g: str) -> B: ...
5449+
def f4(g: Any) -> Any: ...
5450+
reveal_type(f4(42)) # N: Revealed type is "__main__.A"
5451+
reveal_type(f4("Hello")) # N: Revealed type is "__main__.B"
5452+
5453+
if True:
5454+
# Some comment
5455+
@overload
5456+
def f5(g: int) -> A: ...
5457+
@overload
5458+
def f5(g: str) -> B: ...
5459+
def f5(g: Any) -> Any: ...
5460+
reveal_type(f5(42)) # N: Revealed type is "__main__.A"
5461+
reveal_type(f5("Hello")) # N: Revealed type is "__main__.B"
5462+
5463+
[case testOverloadIfNotMatching]
5464+
from typing import overload, Any
5465+
5466+
class A: ...
5467+
class B: ...
5468+
class C: ...
5469+
5470+
@overload # E: An overloaded function outside a stub file must have an implementation
5471+
def f1(g: int) -> A: ...
5472+
@overload
5473+
def f1(g: bytes) -> B: ...
5474+
if True:
5475+
@overload # E: Name "f1" already defined on line 7 \
5476+
# E: Single overload definition, multiple required
5477+
def f1(g: str) -> C: ...
5478+
pass # Some other action
5479+
def f1(g: Any) -> Any: ... # E: Name "f1" already defined on line 7
5480+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5481+
reveal_type(f1("Hello")) # E: No overload variant of "f1" matches argument type "str" \
5482+
# N: Possible overload variants: \
5483+
# N: def f1(g: int) -> A \
5484+
# N: def f1(g: bytes) -> B \
5485+
# N: Revealed type is "Any"
5486+
5487+
if True:
5488+
pass # Some other action
5489+
@overload # E: Single overload definition, multiple required
5490+
def f2(g: int) -> A: ...
5491+
@overload # E: Name "f2" already defined on line 21
5492+
def f2(g: bytes) -> B: ...
5493+
@overload
5494+
def f2(g: str) -> C: ...
5495+
def f2(g: Any) -> Any: ...
5496+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5497+
reveal_type(f2("Hello")) # N: Revealed type is "__main__.A" \
5498+
# E: Argument 1 to "f2" has incompatible type "str"; expected "int"

0 commit comments

Comments
 (0)