Skip to content

Commit 0334ebc

Browse files
Support for variadic type aliases (#15219)
Fixes #15062 Implementing "happy path" took like couple dozen lines, but there are *a lot* of edge cases, e.g. where we need to fail gracefully. Also of course I checked my implementation (mostly) works for recursive variadic aliases :-) see test. It looks like several pieces of support for proper variadic types (i.e. non-aliases, instances etc) are still missing, so I tried to fill in something where I needed it for type aliases, but not everywhere, some notable examples: * Type variable bound checks for instances are still broken, see TODO item in `semanal_typeargs.py` * I think type argument count check is still broken for instances (I think I fixed it for type aliases), there can be fewer than `len(type_vars) - 1` type arguments, e.g. if one of them is an unpack. * We should only prohibit multiple *variadic* unpacks in a type list, multiple fixed length unpacks are fine (I think I fixed this both for aliases and instances) Btw I was thinking about an example below, what should we do in such cases? ```python from typing import Tuple, TypeVar from typing_extensions import TypeVarTuple, Unpack T = TypeVar("T") S = TypeVar("S") Ts = TypeVarTuple("Ts") Alias = Tuple[T, S, Unpack[Ts], S] def foo(*x: Unpack[Ts]) -> None: y: Alias[Unpack[Ts], int, str] reveal_type(y) # <-- what is this type? # Ts = () => Tuple[int, str, str] # Ts = (bool) => Tuple[bool, int, str, int] # Ts = (bool, float) => Tuple[bool, float, int, str, float] ``` Finally, I noticed there is already some code duplication, and I am not improving it. I am open to suggestions on how to reduce the code duplication. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 391ed85 commit 0334ebc

12 files changed

+507
-61
lines changed

mypy/checkexpr.py

+44-4
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,13 @@
151151
UninhabitedType,
152152
UnionType,
153153
UnpackType,
154+
flatten_nested_tuples,
154155
flatten_nested_unions,
155156
get_proper_type,
156157
get_proper_types,
157158
has_recursive_types,
158159
is_named_instance,
160+
split_with_prefix_and_suffix,
159161
)
160162
from mypy.types_utils import is_generic_instance, is_optional, is_self_type_like, remove_optional
161163
from mypy.typestate import type_state
@@ -4070,6 +4072,35 @@ class LongName(Generic[T]): ...
40704072
# The _SpecialForm type can be used in some runtime contexts (e.g. it may have __or__).
40714073
return self.named_type("typing._SpecialForm")
40724074

4075+
def split_for_callable(
4076+
self, t: CallableType, args: Sequence[Type], ctx: Context
4077+
) -> list[Type]:
4078+
"""Handle directly applying type arguments to a variadic Callable.
4079+
4080+
This is needed in situations where e.g. variadic class object appears in
4081+
runtime context. For example:
4082+
class C(Generic[T, Unpack[Ts]]): ...
4083+
x = C[int, str]()
4084+
4085+
We simply group the arguments that need to go into Ts variable into a TupleType,
4086+
similar to how it is done in other places using split_with_prefix_and_suffix().
4087+
"""
4088+
vars = t.variables
4089+
if not vars or not any(isinstance(v, TypeVarTupleType) for v in vars):
4090+
return list(args)
4091+
4092+
prefix = next(i for (i, v) in enumerate(vars) if isinstance(v, TypeVarTupleType))
4093+
suffix = len(vars) - prefix - 1
4094+
args = flatten_nested_tuples(args)
4095+
if len(args) < len(vars) - 1:
4096+
self.msg.incompatible_type_application(len(vars), len(args), ctx)
4097+
return [AnyType(TypeOfAny.from_error)] * len(vars)
4098+
4099+
tvt = vars[prefix]
4100+
assert isinstance(tvt, TypeVarTupleType)
4101+
start, middle, end = split_with_prefix_and_suffix(tuple(args), prefix, suffix)
4102+
return list(start) + [TupleType(list(middle), tvt.tuple_fallback)] + list(end)
4103+
40734104
def apply_type_arguments_to_callable(
40744105
self, tp: Type, args: Sequence[Type], ctx: Context
40754106
) -> Type:
@@ -4083,19 +4114,28 @@ def apply_type_arguments_to_callable(
40834114
tp = get_proper_type(tp)
40844115

40854116
if isinstance(tp, CallableType):
4086-
if len(tp.variables) != len(args):
4117+
if len(tp.variables) != len(args) and not any(
4118+
isinstance(v, TypeVarTupleType) for v in tp.variables
4119+
):
40874120
if tp.is_type_obj() and tp.type_object().fullname == "builtins.tuple":
40884121
# TODO: Specialize the callable for the type arguments
40894122
return tp
40904123
self.msg.incompatible_type_application(len(tp.variables), len(args), ctx)
40914124
return AnyType(TypeOfAny.from_error)
4092-
return self.apply_generic_arguments(tp, args, ctx)
4125+
return self.apply_generic_arguments(tp, self.split_for_callable(tp, args, ctx), ctx)
40934126
if isinstance(tp, Overloaded):
40944127
for it in tp.items:
4095-
if len(it.variables) != len(args):
4128+
if len(it.variables) != len(args) and not any(
4129+
isinstance(v, TypeVarTupleType) for v in it.variables
4130+
):
40964131
self.msg.incompatible_type_application(len(it.variables), len(args), ctx)
40974132
return AnyType(TypeOfAny.from_error)
4098-
return Overloaded([self.apply_generic_arguments(it, args, ctx) for it in tp.items])
4133+
return Overloaded(
4134+
[
4135+
self.apply_generic_arguments(it, self.split_for_callable(it, args, ctx), ctx)
4136+
for it in tp.items
4137+
]
4138+
)
40994139
return AnyType(TypeOfAny.special_form)
41004140

41014141
def visit_list_expr(self, e: ListExpr) -> Type:

mypy/constraints.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Iterable, List, Sequence
5+
from typing import TYPE_CHECKING, Iterable, List, Sequence, cast
66
from typing_extensions import Final
77

88
import mypy.subtypes
@@ -46,15 +46,11 @@
4646
has_recursive_types,
4747
has_type_vars,
4848
is_named_instance,
49+
split_with_prefix_and_suffix,
4950
)
5051
from mypy.types_utils import is_union_with_any
5152
from mypy.typestate import type_state
52-
from mypy.typevartuples import (
53-
extract_unpack,
54-
find_unpack_in_list,
55-
split_with_mapped_and_template,
56-
split_with_prefix_and_suffix,
57-
)
53+
from mypy.typevartuples import extract_unpack, find_unpack_in_list, split_with_mapped_and_template
5854

5955
if TYPE_CHECKING:
6056
from mypy.infer import ArgumentInferContext
@@ -669,7 +665,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
669665
instance.type.type_var_tuple_prefix,
670666
instance.type.type_var_tuple_suffix,
671667
)
672-
tvars = list(tvars_prefix + tvars_suffix)
668+
tvars = cast("list[TypeVarLikeType]", list(tvars_prefix + tvars_suffix))
673669
else:
674670
mapped_args = mapped.args
675671
instance_args = instance.args
@@ -738,7 +734,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
738734
template.type.type_var_tuple_prefix,
739735
template.type.type_var_tuple_suffix,
740736
)
741-
tvars = list(tvars_prefix + tvars_suffix)
737+
tvars = cast("list[TypeVarLikeType]", list(tvars_prefix + tvars_suffix))
742738
else:
743739
mapped_args = mapped.args
744740
template_args = template.args

mypy/expandtype.py

+26-14
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@
3737
UninhabitedType,
3838
UnionType,
3939
UnpackType,
40+
flatten_nested_tuples,
4041
flatten_nested_unions,
4142
get_proper_type,
42-
)
43-
from mypy.typevartuples import (
44-
find_unpack_in_list,
45-
split_with_instance,
4643
split_with_prefix_and_suffix,
4744
)
45+
from mypy.typevartuples import find_unpack_in_list, split_with_instance
4846

4947
# WARNING: these functions should never (directly or indirectly) depend on
5048
# is_subtype(), meet_types(), join_types() etc.
@@ -115,6 +113,7 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
115113
instance_args = instance.args
116114

117115
for binder, arg in zip(tvars, instance_args):
116+
assert isinstance(binder, TypeVarLikeType)
118117
variables[binder.id] = arg
119118

120119
return expand_type(typ, variables)
@@ -282,12 +281,14 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
282281
raise NotImplementedError
283282

284283
def visit_unpack_type(self, t: UnpackType) -> Type:
285-
# It is impossible to reasonally implement visit_unpack_type, because
284+
# It is impossible to reasonably implement visit_unpack_type, because
286285
# unpacking inherently expands to something more like a list of types.
287286
#
288287
# Relevant sections that can call unpack should call expand_unpack()
289288
# instead.
290-
assert False, "Mypy bug: unpacking must happen at a higher level"
289+
# However, if the item is a variadic tuple, we can simply carry it over.
290+
# it is hard to assert this without getting proper type.
291+
return UnpackType(t.type.accept(self))
291292

292293
def expand_unpack(self, t: UnpackType) -> list[Type] | Instance | AnyType | None:
293294
return expand_unpack_with_variables(t, self.variables)
@@ -356,7 +357,15 @@ def interpolate_args_for_unpack(
356357

357358
# Extract the typevartuple so we can get a tuple fallback from it.
358359
expanded_unpacked_tvt = expanded_unpack.type
359-
assert isinstance(expanded_unpacked_tvt, TypeVarTupleType)
360+
if isinstance(expanded_unpacked_tvt, TypeVarTupleType):
361+
fallback = expanded_unpacked_tvt.tuple_fallback
362+
else:
363+
# This can happen when tuple[Any, ...] is used to "patch" a variadic
364+
# generic type without type arguments provided.
365+
assert isinstance(expanded_unpacked_tvt, ProperType)
366+
assert isinstance(expanded_unpacked_tvt, Instance)
367+
assert expanded_unpacked_tvt.type.fullname == "builtins.tuple"
368+
fallback = expanded_unpacked_tvt
360369

361370
prefix_len = expanded_unpack_index
362371
arg_names = t.arg_names[:star_index] + [None] * prefix_len + t.arg_names[star_index:]
@@ -368,11 +377,7 @@ def interpolate_args_for_unpack(
368377
+ expanded_items[:prefix_len]
369378
# Constructing the Unpack containing the tuple without the prefix.
370379
+ [
371-
UnpackType(
372-
TupleType(
373-
expanded_items[prefix_len:], expanded_unpacked_tvt.tuple_fallback
374-
)
375-
)
380+
UnpackType(TupleType(expanded_items[prefix_len:], fallback))
376381
if len(expanded_items) - prefix_len > 1
377382
else expanded_items[0]
378383
]
@@ -456,9 +461,12 @@ def expand_types_with_unpack(
456461
indicates use of Any or some error occurred earlier. In this case callers should
457462
simply propagate the resulting type.
458463
"""
464+
# TODO: this will cause a crash on aliases like A = Tuple[int, Unpack[A]].
465+
# Although it is unlikely anyone will write this, we should fail gracefully.
466+
typs = flatten_nested_tuples(typs)
459467
items: list[Type] = []
460468
for item in typs:
461-
if isinstance(item, UnpackType):
469+
if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
462470
unpacked_items = self.expand_unpack(item)
463471
if unpacked_items is None:
464472
# TODO: better error, something like tuple of unknown?
@@ -523,7 +531,11 @@ def visit_type_type(self, t: TypeType) -> Type:
523531
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
524532
# Target of the type alias cannot contain type variables (not bound by the type
525533
# alias itself), so we just expand the arguments.
526-
return t.copy_modified(args=self.expand_types(t.args))
534+
args = self.expand_types_with_unpack(t.args)
535+
if isinstance(args, list):
536+
return t.copy_modified(args=args)
537+
else:
538+
return args
527539

528540
def expand_types(self, types: Iterable[Type]) -> list[Type]:
529541
a: list[Type] = []

mypy/nodes.py

+5
Original file line numberDiff line numberDiff line change
@@ -3471,6 +3471,7 @@ def f(x: B[T]) -> T: ... # without T, Any would be used here
34713471
"normalized",
34723472
"_is_recursive",
34733473
"eager",
3474+
"tvar_tuple_index",
34743475
)
34753476

34763477
__match_args__ = ("name", "target", "alias_tvars", "no_args")
@@ -3498,6 +3499,10 @@ def __init__(
34983499
# it is the cached value.
34993500
self._is_recursive: bool | None = None
35003501
self.eager = eager
3502+
self.tvar_tuple_index = None
3503+
for i, t in enumerate(alias_tvars):
3504+
if isinstance(t, mypy.types.TypeVarTupleType):
3505+
self.tvar_tuple_index = i
35013506
super().__init__(line, column)
35023507

35033508
@classmethod

mypy/semanal.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@
270270
TypeOfAny,
271271
TypeType,
272272
TypeVarLikeType,
273+
TypeVarTupleType,
273274
TypeVarType,
274275
UnboundType,
275276
UnpackType,
@@ -3424,8 +3425,18 @@ def analyze_alias(
34243425
allowed_alias_tvars=tvar_defs,
34253426
)
34263427

3428+
# There can be only one variadic variable at most, the error is reported elsewhere.
3429+
new_tvar_defs = []
3430+
variadic = False
3431+
for td in tvar_defs:
3432+
if isinstance(td, TypeVarTupleType):
3433+
if variadic:
3434+
continue
3435+
variadic = True
3436+
new_tvar_defs.append(td)
3437+
34273438
qualified_tvars = [node.fullname for _name, node in found_type_vars]
3428-
return analyzed, tvar_defs, depends_on, qualified_tvars
3439+
return analyzed, new_tvar_defs, depends_on, qualified_tvars
34293440

34303441
def is_pep_613(self, s: AssignmentStmt) -> bool:
34313442
if s.unanalyzed_type is not None and isinstance(s.unanalyzed_type, UnboundType):

mypy/semanal_typeargs.py

+46-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from mypy.options import Options
1919
from mypy.scope import Scope
2020
from mypy.subtypes import is_same_type, is_subtype
21+
from mypy.typeanal import set_any_tvars
2122
from mypy.types import (
2223
AnyType,
2324
Instance,
@@ -32,8 +33,10 @@
3233
TypeVarType,
3334
UnboundType,
3435
UnpackType,
36+
flatten_nested_tuples,
3537
get_proper_type,
3638
get_proper_types,
39+
split_with_prefix_and_suffix,
3740
)
3841

3942

@@ -79,10 +82,34 @@ def visit_type_alias_type(self, t: TypeAliasType) -> None:
7982
self.seen_aliases.add(t)
8083
# Some recursive aliases may produce spurious args. In principle this is not very
8184
# important, as we would simply ignore them when expanding, but it is better to keep
82-
# correct aliases.
83-
if t.alias and len(t.args) != len(t.alias.alias_tvars):
84-
t.args = [AnyType(TypeOfAny.from_error) for _ in t.alias.alias_tvars]
85+
# correct aliases. Also, variadic aliases are better to check when fully analyzed,
86+
# so we do this here.
8587
assert t.alias is not None, f"Unfixed type alias {t.type_ref}"
88+
args = flatten_nested_tuples(t.args)
89+
if t.alias.tvar_tuple_index is not None:
90+
correct = len(args) >= len(t.alias.alias_tvars) - 1
91+
if any(
92+
isinstance(a, UnpackType) and isinstance(get_proper_type(a.type), Instance)
93+
for a in args
94+
):
95+
correct = True
96+
else:
97+
correct = len(args) == len(t.alias.alias_tvars)
98+
if not correct:
99+
if t.alias.tvar_tuple_index is not None:
100+
exp_len = f"at least {len(t.alias.alias_tvars) - 1}"
101+
else:
102+
exp_len = f"{len(t.alias.alias_tvars)}"
103+
self.fail(
104+
f"Bad number of arguments for type alias, expected: {exp_len}, given: {len(args)}",
105+
t,
106+
code=codes.TYPE_ARG,
107+
)
108+
t.args = set_any_tvars(
109+
t.alias, t.line, t.column, self.options, from_error=True, fail=self.fail
110+
).args
111+
else:
112+
t.args = args
86113
is_error = self.validate_args(t.alias.name, t.args, t.alias.alias_tvars, t)
87114
if not is_error:
88115
# If there was already an error for the alias itself, there is no point in checking
@@ -101,6 +128,17 @@ def visit_instance(self, t: Instance) -> None:
101128
def validate_args(
102129
self, name: str, args: Sequence[Type], type_vars: list[TypeVarLikeType], ctx: Context
103130
) -> bool:
131+
# TODO: we need to do flatten_nested_tuples and validate arg count for instances
132+
# similar to how do we do this for type aliases above, but this may have perf penalty.
133+
if any(isinstance(v, TypeVarTupleType) for v in type_vars):
134+
prefix = next(i for (i, v) in enumerate(type_vars) if isinstance(v, TypeVarTupleType))
135+
tvt = type_vars[prefix]
136+
assert isinstance(tvt, TypeVarTupleType)
137+
start, middle, end = split_with_prefix_and_suffix(
138+
tuple(args), prefix, len(type_vars) - prefix - 1
139+
)
140+
args = list(start) + [TupleType(list(middle), tvt.tuple_fallback)] + list(end)
141+
104142
is_error = False
105143
for (i, arg), tvar in zip(enumerate(args), type_vars):
106144
if isinstance(tvar, TypeVarType):
@@ -167,7 +205,11 @@ def visit_unpack_type(self, typ: UnpackType) -> None:
167205
return
168206
if isinstance(proper_type, Instance) and proper_type.type.fullname == "builtins.tuple":
169207
return
170-
if isinstance(proper_type, AnyType) and proper_type.type_of_any == TypeOfAny.from_error:
208+
if (
209+
isinstance(proper_type, UnboundType)
210+
or isinstance(proper_type, AnyType)
211+
and proper_type.type_of_any == TypeOfAny.from_error
212+
):
171213
return
172214

173215
# TODO: Infer something when it can't be unpacked to allow rest of

mypy/subtypes.py

+2
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,8 @@ def visit_type_var_tuple(self, left: TypeVarTupleType) -> bool:
661661
def visit_unpack_type(self, left: UnpackType) -> bool:
662662
if isinstance(self.right, UnpackType):
663663
return self._is_subtype(left.type, self.right.type)
664+
if isinstance(self.right, Instance) and self.right.type.fullname == "builtins.object":
665+
return True
664666
return False
665667

666668
def visit_parameters(self, left: Parameters) -> bool:

0 commit comments

Comments
 (0)