Skip to content

Commit 14eaa18

Browse files
authored
Plugin for tuple multiplication with literal int (#10361)
1 parent 21991cf commit 14eaa18

File tree

4 files changed

+52
-2
lines changed

4 files changed

+52
-2
lines changed

mypy/plugins/default.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from mypy.plugins.common import try_getting_str_literals
1111
from mypy.types import (
1212
FunctionLike, Type, Instance, AnyType, TypeOfAny, CallableType, NoneType, TypedDictType,
13-
TypeVarType, TPDICT_FB_NAMES, get_proper_type, LiteralType
13+
TypeVarType, TPDICT_FB_NAMES, get_proper_type, LiteralType, TupleType
1414
)
1515
from mypy.subtypes import is_subtype
1616
from mypy.typeops import make_simplified_union
@@ -64,6 +64,8 @@ def get_method_hook(self, fullname: str
6464
return int_pow_callback
6565
elif fullname == 'builtins.int.__neg__':
6666
return int_neg_callback
67+
elif fullname in ('builtins.tuple.__mul__', 'builtins.tuple.__rmul__'):
68+
return tuple_mul_callback
6769
elif fullname in set(n + '.setdefault' for n in TPDICT_FB_NAMES):
6870
return typed_dict_setdefault_callback
6971
elif fullname in set(n + '.pop' for n in TPDICT_FB_NAMES):
@@ -471,3 +473,24 @@ def int_neg_callback(ctx: MethodContext) -> Type:
471473
if isinstance(value, int):
472474
return LiteralType(value=-value, fallback=fallback)
473475
return ctx.default_return_type
476+
477+
478+
def tuple_mul_callback(ctx: MethodContext) -> Type:
479+
"""Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.
480+
481+
This is used to return a specific sized tuple if multiplied by Literal int
482+
"""
483+
if not isinstance(ctx.type, TupleType):
484+
return ctx.default_return_type
485+
486+
arg_type = ctx.arg_types[0][0]
487+
if isinstance(arg_type, Instance) and arg_type.last_known_value is not None:
488+
value = arg_type.last_known_value.value
489+
if isinstance(value, int):
490+
return ctx.type.copy_modified(items=ctx.type.items * value)
491+
elif isinstance(ctx.type, LiteralType):
492+
value = arg_type.value
493+
if isinstance(value, int):
494+
return ctx.type.copy_modified(items=ctx.type.items * value)
495+
496+
return ctx.default_return_type

test-data/unit/check-generics.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1672,7 +1672,7 @@ def f(x: T) -> str:
16721672
[case testTypeVarReversibleOperatorTuple]
16731673
from typing import TypeVar, Tuple
16741674
class A(Tuple[int, int]):
1675-
def __mul__(cls, other: Tuple[int, int]) -> str: return ""
1675+
def __mul__(cls, other: Tuple[int, int]) -> str: return "" # type: ignore # overriding default __mul__
16761676
T = TypeVar("T", bound=A)
16771677
def f(x: T) -> str:
16781678
return reveal_type(x * (1, 2) ) # N: Revealed type is "builtins.str"

test-data/unit/check-tuples.test

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,32 @@ x9, y9, x10, y10, z5 = *points2, 1, *points2 # E: Contiguous iterable with same
14711471
() = 1 # E: "Literal[1]?" object is not iterable
14721472
[builtins fixtures/tuple.pyi]
14731473

1474+
[case testMultiplyTupleByIntegerLiteral]
1475+
from typing import Tuple
1476+
t = ('',) * 2
1477+
reveal_type(t) # N: Revealed type is "Tuple[builtins.str, builtins.str]"
1478+
t2 = ('',) * -1
1479+
reveal_type(t2) # N: Revealed type is "Tuple[]"
1480+
t3 = ('', 1) * 2
1481+
reveal_type(t3) # N: Revealed type is "Tuple[builtins.str, builtins.int, builtins.str, builtins.int]"
1482+
def f() -> Tuple[str, ...]:
1483+
return ('', )
1484+
reveal_type(f() * 2) # N: Revealed type is "builtins.tuple[builtins.str*]"
1485+
[builtins fixtures/tuple.pyi]
1486+
1487+
[case testMultiplyTupleByIntegerLiteralReverse]
1488+
from typing import Tuple
1489+
t = 2 * ('',)
1490+
reveal_type(t) # N: Revealed type is "Tuple[builtins.str, builtins.str]"
1491+
t2 = -1 * ('',)
1492+
reveal_type(t2) # N: Revealed type is "Tuple[]"
1493+
t3 = 2 * ('', 1)
1494+
reveal_type(t3) # N: Revealed type is "Tuple[builtins.str, builtins.int, builtins.str, builtins.int]"
1495+
def f() -> Tuple[str, ...]:
1496+
return ('', )
1497+
reveal_type(2 * f()) # N: Revealed type is "builtins.tuple[builtins.str*]"
1498+
[builtins fixtures/tuple.pyi]
1499+
14741500
[case testSingleUndefinedTypeAndTuple]
14751501
from typing import Tuple
14761502

test-data/unit/fixtures/tuple.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class tuple(Sequence[Tco], Generic[Tco]):
1414
def __iter__(self) -> Iterator[Tco]: pass
1515
def __contains__(self, item: object) -> bool: pass
1616
def __getitem__(self, x: int) -> Tco: pass
17+
def __mul__(self, n: int) -> Tuple[Tco, ...]: pass
1718
def __rmul__(self, n: int) -> Tuple[Tco, ...]: pass
1819
def __add__(self, x: Tuple[Tco, ...]) -> Tuple[Tco, ...]: pass
1920
def count(self, obj: object) -> int: pass

0 commit comments

Comments
 (0)