Skip to content

Commit c588852

Browse files
authored
Speed up freshening type variables (#14323)
Only perform type variable freshening if it's needed, i.e. there is a nested generic callable, since it's fairly expensive. Make the check for generic callables fast by creating a specialized type query visitor base class for queries with bool results. The visitor tries hard to avoid memory allocation in typical cases, since allocation is slow. This addresses at least some of the performance regression in #14095. This improved self-check performance by about 3% when compiled with mypyc (-O2). The new visitor class can potentially help with other type queries as well. I'll explore it in follow-up PRs.
1 parent c414464 commit c588852

File tree

3 files changed

+183
-4
lines changed

3 files changed

+183
-4
lines changed

mypy/expandtype.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

33
from typing import Iterable, Mapping, Sequence, TypeVar, cast, overload
4+
from typing_extensions import Final
45

56
from mypy.nodes import ARG_POS, ARG_STAR, Var
67
from mypy.type_visitor import TypeTranslator
78
from mypy.types import (
9+
ANY_STRATEGY,
810
AnyType,
11+
BoolTypeQuery,
912
CallableType,
1013
DeletedType,
1114
ErasedType,
@@ -138,13 +141,30 @@ def freshen_function_type_vars(callee: F) -> F:
138141
return cast(F, fresh_overload)
139142

140143

144+
class HasGenericCallable(BoolTypeQuery):
145+
def __init__(self) -> None:
146+
super().__init__(ANY_STRATEGY)
147+
148+
def visit_callable_type(self, t: CallableType) -> bool:
149+
return t.is_generic() or super().visit_callable_type(t)
150+
151+
152+
# Share a singleton since this is performance sensitive
153+
has_generic_callable: Final = HasGenericCallable()
154+
155+
141156
T = TypeVar("T", bound=Type)
142157

143158

144159
def freshen_all_functions_type_vars(t: T) -> T:
145-
result = t.accept(FreshenCallableVisitor())
146-
assert isinstance(result, type(t))
147-
return result
160+
result: Type
161+
has_generic_callable.reset()
162+
if not t.accept(has_generic_callable):
163+
return t # Fast path to avoid expensive freshening
164+
else:
165+
result = t.accept(FreshenCallableVisitor())
166+
assert isinstance(result, type(t))
167+
return result
148168

149169

150170
class FreshenCallableVisitor(TypeTranslator):

mypy/type_visitor.py

+156
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from abc import abstractmethod
1717
from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar, cast
18+
from typing_extensions import Final
1819

1920
from mypy_extensions import mypyc_attr, trait
2021

@@ -417,3 +418,158 @@ def visit_type_alias_type(self, t: TypeAliasType) -> T:
417418
def query_types(self, types: Iterable[Type]) -> T:
418419
"""Perform a query for a list of types using the strategy to combine the results."""
419420
return self.strategy([t.accept(self) for t in types])
421+
422+
423+
# Return True if at least one type component returns True
424+
ANY_STRATEGY: Final = 0
425+
# Return True if no type component returns False
426+
ALL_STRATEGY: Final = 1
427+
428+
429+
class BoolTypeQuery(SyntheticTypeVisitor[bool]):
430+
"""Visitor for performing recursive queries of types with a bool result.
431+
432+
Use TypeQuery if you need non-bool results.
433+
434+
'strategy' is used to combine results for a series of types. It must
435+
be ANY_STRATEGY or ALL_STRATEGY.
436+
437+
Note: This visitor keeps an internal state (tracks type aliases to avoid
438+
recursion), so it should *never* be re-used for querying different types
439+
unless you call reset() first.
440+
"""
441+
442+
def __init__(self, strategy: int) -> None:
443+
self.strategy = strategy
444+
if strategy == ANY_STRATEGY:
445+
self.default = False
446+
else:
447+
assert strategy == ALL_STRATEGY
448+
self.default = True
449+
# Keep track of the type aliases already visited. This is needed to avoid
450+
# infinite recursion on types like A = Union[int, List[A]]. An empty set is
451+
# represented as None as a micro-optimization.
452+
self.seen_aliases: set[TypeAliasType] | None = None
453+
# By default, we eagerly expand type aliases, and query also types in the
454+
# alias target. In most cases this is a desired behavior, but we may want
455+
# to skip targets in some cases (e.g. when collecting type variables).
456+
self.skip_alias_target = False
457+
458+
def reset(self) -> None:
459+
"""Clear mutable state (but preserve strategy).
460+
461+
This *must* be called if you want to reuse the visitor.
462+
"""
463+
self.seen_aliases = None
464+
465+
def visit_unbound_type(self, t: UnboundType) -> bool:
466+
return self.query_types(t.args)
467+
468+
def visit_type_list(self, t: TypeList) -> bool:
469+
return self.query_types(t.items)
470+
471+
def visit_callable_argument(self, t: CallableArgument) -> bool:
472+
return t.typ.accept(self)
473+
474+
def visit_any(self, t: AnyType) -> bool:
475+
return self.default
476+
477+
def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
478+
return self.default
479+
480+
def visit_none_type(self, t: NoneType) -> bool:
481+
return self.default
482+
483+
def visit_erased_type(self, t: ErasedType) -> bool:
484+
return self.default
485+
486+
def visit_deleted_type(self, t: DeletedType) -> bool:
487+
return self.default
488+
489+
def visit_type_var(self, t: TypeVarType) -> bool:
490+
return self.query_types([t.upper_bound] + t.values)
491+
492+
def visit_param_spec(self, t: ParamSpecType) -> bool:
493+
return self.default
494+
495+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
496+
return self.default
497+
498+
def visit_unpack_type(self, t: UnpackType) -> bool:
499+
return self.query_types([t.type])
500+
501+
def visit_parameters(self, t: Parameters) -> bool:
502+
return self.query_types(t.arg_types)
503+
504+
def visit_partial_type(self, t: PartialType) -> bool:
505+
return self.default
506+
507+
def visit_instance(self, t: Instance) -> bool:
508+
return self.query_types(t.args)
509+
510+
def visit_callable_type(self, t: CallableType) -> bool:
511+
# FIX generics
512+
# Avoid allocating any objects here as an optimization.
513+
args = self.query_types(t.arg_types)
514+
ret = t.ret_type.accept(self)
515+
if self.strategy == ANY_STRATEGY:
516+
return args or ret
517+
else:
518+
return args and ret
519+
520+
def visit_tuple_type(self, t: TupleType) -> bool:
521+
return self.query_types(t.items)
522+
523+
def visit_typeddict_type(self, t: TypedDictType) -> bool:
524+
return self.query_types(list(t.items.values()))
525+
526+
def visit_raw_expression_type(self, t: RawExpressionType) -> bool:
527+
return self.default
528+
529+
def visit_literal_type(self, t: LiteralType) -> bool:
530+
return self.default
531+
532+
def visit_star_type(self, t: StarType) -> bool:
533+
return t.type.accept(self)
534+
535+
def visit_union_type(self, t: UnionType) -> bool:
536+
return self.query_types(t.items)
537+
538+
def visit_overloaded(self, t: Overloaded) -> bool:
539+
return self.query_types(t.items) # type: ignore[arg-type]
540+
541+
def visit_type_type(self, t: TypeType) -> bool:
542+
return t.item.accept(self)
543+
544+
def visit_ellipsis_type(self, t: EllipsisType) -> bool:
545+
return self.default
546+
547+
def visit_placeholder_type(self, t: PlaceholderType) -> bool:
548+
return self.query_types(t.args)
549+
550+
def visit_type_alias_type(self, t: TypeAliasType) -> bool:
551+
# Skip type aliases already visited types to avoid infinite recursion.
552+
# TODO: Ideally we should fire subvisitors here (or use caching) if we care
553+
# about duplicates.
554+
if self.seen_aliases is None:
555+
self.seen_aliases = set()
556+
elif t in self.seen_aliases:
557+
return self.default
558+
self.seen_aliases.add(t)
559+
if self.skip_alias_target:
560+
return self.query_types(t.args)
561+
return get_proper_type(t).accept(self)
562+
563+
def query_types(self, types: list[Type] | tuple[Type, ...]) -> bool:
564+
"""Perform a query for a sequence of types using the strategy to combine the results."""
565+
# Special-case for lists and tuples to allow mypyc to produce better code.
566+
if isinstance(types, list):
567+
if self.strategy == ANY_STRATEGY:
568+
return any(t.accept(self) for t in types)
569+
else:
570+
return all(t.accept(self) for t in types)
571+
else:
572+
if self.strategy == ANY_STRATEGY:
573+
return any(t.accept(self) for t in types)
574+
else:
575+
return all(t.accept(self) for t in types)

mypy/types.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -2879,7 +2879,10 @@ def get_proper_types(it: Iterable[Type | None]) -> list[ProperType] | list[Prope
28792879
# to make it easier to gradually get modules working with mypyc.
28802880
# Import them here, after the types are defined.
28812881
# This is intended as a re-export also.
2882-
from mypy.type_visitor import ( # noqa: F811
2882+
from mypy.type_visitor import ( # noqa: F811,F401
2883+
ALL_STRATEGY as ALL_STRATEGY,
2884+
ANY_STRATEGY as ANY_STRATEGY,
2885+
BoolTypeQuery as BoolTypeQuery,
28832886
SyntheticTypeVisitor as SyntheticTypeVisitor,
28842887
TypeQuery as TypeQuery,
28852888
TypeTranslator as TypeTranslator,

0 commit comments

Comments
 (0)