Skip to content

Speed up ArgKind methods by changing them into top-level functions #11546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mypy/argmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind],
for ai, actual_kind in enumerate(actual_kinds):
if actual_kind == nodes.ARG_POS:
if fi < nformals:
if not formal_kinds[fi].is_star():
if not nodes.is_star(formal_kinds[fi]):
formal_to_actual[fi].append(ai)
fi += 1
elif formal_kinds[fi] == nodes.ARG_STAR:
Expand All @@ -55,14 +55,14 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind],
# Assume that it is an iterable (if it isn't, there will be
# an error later).
while fi < nformals:
if formal_kinds[fi].is_named(star=True):
if nodes.is_named(formal_kinds[fi], star=True):
break
else:
formal_to_actual[fi].append(ai)
if formal_kinds[fi] == nodes.ARG_STAR:
break
fi += 1
elif actual_kind.is_named():
elif nodes.is_named(actual_kind):
assert actual_names is not None, "Internal error: named kinds without names given"
name = actual_names[ai]
if name in formal_names:
Expand Down
10 changes: 5 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
Instance, NoneType, strip_type, TypeType, TypeOfAny,
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType,
is_named_instance, union_items, TypeQuery, LiteralType,
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
is_optional_type, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType)
from mypy.sametypes import is_same_type
from mypy.messages import (
Expand Down Expand Up @@ -4512,11 +4512,11 @@ def has_no_custom_eq_checks(t: Type) -> bool:
collection_type = operand_types[right_index]

# We only try and narrow away 'None' for now
if not is_optional(item_type):
if not is_optional_type(item_type):
continue

collection_item_type = get_proper_type(builtin_item_type(collection_type))
if collection_item_type is None or is_optional(collection_item_type):
if collection_item_type is None or is_optional_type(collection_item_type):
continue
if (isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == 'builtins.object'):
Expand Down Expand Up @@ -4904,7 +4904,7 @@ def refine_away_none_in_comparison(self,
non_optional_types = []
for i in chain_indices:
typ = operand_types[i]
if not is_optional(typ):
if not is_optional_type(typ):
non_optional_types.append(typ)

# Make sure we have a mixture of optional and non-optional types.
Expand All @@ -4914,7 +4914,7 @@ def refine_away_none_in_comparison(self,
if_map = {}
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if not is_optional(expr_type):
if not is_optional_type(expr_type):
continue
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
if_map[operands[i]] = remove_optional(expr_type)
Expand Down
17 changes: 9 additions & 8 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
TupleType, TypedDictType, Instance, ErasedType, UnionType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
is_named_instance, FunctionLike, ParamSpecType,
StarType, is_optional, remove_optional, is_generic_instance, get_proper_type, ProperType,
StarType, is_optional_type, remove_optional, is_generic_instance, get_proper_type, ProperType,
get_proper_types, flatten_nested_unions
)
from mypy.nodes import (
Expand All @@ -34,6 +34,7 @@
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode,
ParamSpecExpr,
ArgKind, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE,
is_required, is_positional, is_star, is_named
)
from mypy.literals import literal
from mypy import nodes
Expand Down Expand Up @@ -1127,7 +1128,7 @@ def infer_arg_types_in_context(

for i, actuals in enumerate(formal_to_actual):
for ai in actuals:
if not arg_kinds[ai].is_star():
if not is_star(arg_kinds[ai]):
res[ai] = self.accept(args[ai], callee.arg_types[i])

# Fill in the rest of the argument types.
Expand Down Expand Up @@ -1155,7 +1156,7 @@ def infer_function_type_arguments_using_context(
# valid results.
erased_ctx = replace_meta_vars(ctx, ErasedType())
ret_type = callable.ret_type
if is_optional(ret_type) and is_optional(ctx):
if is_optional_type(ret_type) and is_optional_type(ctx):
# If both the context and the return type are optional, unwrap the optional,
# since in 99% cases this is what a user expects. In other words, we replace
# Optional[T] <: Optional[int]
Expand Down Expand Up @@ -1389,24 +1390,24 @@ def check_argument_count(self,

# Check for too many or few values for formals.
for i, kind in enumerate(callee.arg_kinds):
if kind.is_required() and not formal_to_actual[i] and not is_unexpected_arg_error:
if is_required(kind) and not formal_to_actual[i] and not is_unexpected_arg_error:
# No actual for a mandatory formal
if messages:
if kind.is_positional():
if is_positional(kind):
messages.too_few_arguments(callee, context, actual_names)
else:
argname = callee.arg_names[i] or "?"
messages.missing_named_argument(callee, context, argname)
ok = False
elif not kind.is_star() and is_duplicate_mapping(
elif not is_star(kind) and is_duplicate_mapping(
formal_to_actual[i], actual_types, actual_kinds):
if (self.chk.in_checked_function() or
isinstance(get_proper_type(actual_types[formal_to_actual[i][0]]),
TupleType)):
if messages:
messages.duplicate_argument_value(callee, i, context)
ok = False
elif (kind.is_named() and formal_to_actual[i] and
elif (is_named(kind) and formal_to_actual[i] and
actual_kinds[formal_to_actual[i][0]] not in [nodes.ARG_NAMED, nodes.ARG_STAR2]):
# Positional argument when expecting a keyword argument.
if messages:
Expand Down Expand Up @@ -1948,7 +1949,7 @@ def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, C
for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)):
if new_kind == target_kind:
continue
elif new_kind.is_positional() and target_kind.is_positional():
elif is_positional(new_kind) and is_positional(target_kind):
new_kinds[i] = ARG_POS
else:
too_complex = True
Expand Down
4 changes: 2 additions & 2 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype,
is_protocol_implementation, find_member
)
from mypy.nodes import INVARIANT, COVARIANT, CONTRAVARIANT
from mypy.nodes import INVARIANT, COVARIANT, CONTRAVARIANT, is_named
import mypy.typeops
from mypy import state

Expand Down Expand Up @@ -532,7 +532,7 @@ def combine_arg_names(t: CallableType, s: CallableType) -> List[Optional[str]]:
for i in range(num_args):
t_name = t.arg_names[i]
s_name = s.arg_names[i]
if t_name == s_name or t.arg_kinds[i].is_named() or s.arg_kinds[i].is_named():
if t_name == s_name or is_named(t.arg_kinds[i]) or is_named(s.arg_kinds[i]):
new_names.append(t_name)
else:
new_names.append(None)
Expand Down
11 changes: 6 additions & 5 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
TypeInfo, Context, MypyFile, FuncDef, reverse_builtin_aliases,
ArgKind, ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2,
ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode,
CallExpr, IndexExpr, StrExpr, SymbolTable, TempNode, SYMBOL_FUNCBASE_TYPES
CallExpr, IndexExpr, StrExpr, SymbolTable, TempNode, SYMBOL_FUNCBASE_TYPES, is_positional,
is_named, is_star, is_optional
)
from mypy.operators import op_methods, op_methods_to_symbols
from mypy.subtypes import (
Expand Down Expand Up @@ -1764,12 +1765,12 @@ def format(typ: Type) -> str:
for arg_name, arg_type, arg_kind in zip(
func.arg_names, func.arg_types, func.arg_kinds):
if (arg_kind == ARG_POS and arg_name is None
or verbosity == 0 and arg_kind.is_positional()):
or verbosity == 0 and is_positional(arg_kind)):

arg_strings.append(format(arg_type))
else:
constructor = ARG_CONSTRUCTOR_NAMES[arg_kind]
if arg_kind.is_star() or arg_name is None:
if is_star(arg_kind) or arg_name is None:
arg_strings.append("{}({})".format(
constructor,
format(arg_type)))
Expand Down Expand Up @@ -1912,7 +1913,7 @@ def [T <: int] f(self, x: int, y: T) -> None
for i in range(len(tp.arg_types)):
if s:
s += ', '
if tp.arg_kinds[i].is_named() and not asterisk:
if is_named(tp.arg_kinds[i]) and not asterisk:
s += '*, '
asterisk = True
if tp.arg_kinds[i] == ARG_STAR:
Expand All @@ -1924,7 +1925,7 @@ def [T <: int] f(self, x: int, y: T) -> None
if name:
s += name + ': '
s += format_type_bare(tp.arg_types[i])
if tp.arg_kinds[i].is_optional():
if is_optional(tp.arg_kinds[i]):
s += ' = ...'

# If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list
Expand Down
43 changes: 24 additions & 19 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,9 +1625,9 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_member_expr(self)


# Kinds of arguments
@unique
class ArgKind(Enum):
"""Kinds of arguments"""
# Positional argument
ARG_POS = 0
# Positional, optional argument (functions only, not calls)
Expand All @@ -1641,28 +1641,33 @@ class ArgKind(Enum):
# In an argument list, keyword-only and also optional
ARG_NAMED_OPT = 5

def is_positional(self, star: bool = False) -> bool:
return (
self == ARG_POS
or self == ARG_OPT
or (star and self == ARG_STAR)
)

def is_named(self, star: bool = False) -> bool:
return (
self == ARG_NAMED
or self == ARG_NAMED_OPT
or (star and self == ARG_STAR2)
)
def is_positional(kind: ArgKind, star: bool = False) -> bool:
return (
kind == ARG_POS
or kind == ARG_OPT
or (star and kind == ARG_STAR)
)


def is_named(kind: ArgKind, star: bool = False) -> bool:
return (
kind == ARG_NAMED
or kind == ARG_NAMED_OPT
or (star and kind == ARG_STAR2)
)


def is_required(kind: ArgKind) -> bool:
return kind == ARG_POS or kind == ARG_NAMED


def is_required(self) -> bool:
return self == ARG_POS or self == ARG_NAMED
def is_optional(kind: ArgKind) -> bool:
return kind == ARG_OPT or kind == ARG_NAMED_OPT

def is_optional(self) -> bool:
return self == ARG_OPT or self == ARG_NAMED_OPT

def is_star(self) -> bool:
return self == ARG_STAR or self == ARG_STAR2
def is_star(kind: ArgKind) -> bool:
return kind == ARG_STAR or kind == ARG_STAR2


ARG_POS: Final = ArgKind.ARG_POS
Expand Down
6 changes: 3 additions & 3 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mypy.nodes import (
ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
Context, Expression, JsonDict, NameExpr, RefExpr,
SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, PlaceholderNode
SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, PlaceholderNode, is_named
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
Expand Down Expand Up @@ -495,8 +495,8 @@ def _collect_field_args(expr: Expression,
# field() only takes keyword arguments.
args = {}
for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds):
if not kind.is_named():
if kind.is_named(star=True):
if not is_named(kind):
if is_named(kind, star=True):
# This means that `field` is used with `**` unpacking,
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
Expand Down
4 changes: 2 additions & 2 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing_extensions import Final

import mypy.plugin
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var, is_positional
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType, CallableType, get_proper_type, Type, TypeOfAny, UnboundType

Expand Down Expand Up @@ -66,7 +66,7 @@ def _find_other_type(method: _MethodInfo) -> Type:
cur_pos_arg = 0
other_arg = None
for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types):
if arg_kind.is_positional():
if is_positional(arg_kind):
if cur_pos_arg == first_arg_pos:
other_arg = arg_type
break
Expand Down
4 changes: 2 additions & 2 deletions mypy/plugins/singledispatch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from mypy.messages import format_type
from mypy.plugins.common import add_method_to_class
from mypy.nodes import (
ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, Context
ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, Context, is_positional
)
from mypy.subtypes import is_subtype
from mypy.types import (
Expand Down Expand Up @@ -100,7 +100,7 @@ def create_singledispatch_function_callback(ctx: FunctionContext) -> Type:
)
return ctx.default_return_type

elif not func_type.arg_kinds[0].is_positional(star=True):
elif not is_positional(func_type.arg_kinds[0], star=True):
fail(
ctx,
'First argument to singledispatch function must be a positional argument',
Expand Down
5 changes: 2 additions & 3 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@
get_nongen_builtins, get_member_expr_fullname, REVEAL_TYPE,
REVEAL_LOCALS, is_final_node, TypedDictExpr, type_aliases_source_versions,
EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr,
ParamSpecExpr, EllipsisExpr,
FuncBase, implicit_module_attrs,
ParamSpecExpr, EllipsisExpr, FuncBase, implicit_module_attrs, is_named
)
from mypy.tvar_scope import TypeVarLikeScope
from mypy.typevars import fill_typevars
Expand Down Expand Up @@ -3142,7 +3141,7 @@ def process_typevar_parameters(self, args: List[Expression],
contravariant = False
upper_bound: Type = self.object_type()
for param_value, param_name, param_kind in zip(args, names, kinds):
if not param_kind.is_named():
if not is_named(param_kind):
self.fail("Unexpected argument to TypeVar()", context)
return None
if param_name == 'covariant':
Expand Down
4 changes: 2 additions & 2 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def func_helper(self, o: 'mypy.nodes.FuncItem') -> List[object]:
extra: List[Tuple[str, List[mypy.nodes.Var]]] = []
for arg in o.arguments:
kind: mypy.nodes.ArgKind = arg.kind
if kind.is_required():
if mypy.nodes.is_required(kind):
args.append(arg.variable)
elif kind.is_optional():
elif mypy.nodes.is_optional(kind):
assert arg.initializer is not None
args.append(('default', [arg.variable, arg.initializer]))
elif kind == mypy.nodes.ARG_STAR:
Expand Down
4 changes: 2 additions & 2 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr,
ClassDef, MypyFile, Decorator, AssignmentStmt, TypeInfo,
IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, Block,
Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED,
Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, is_named
)
from mypy.stubgenc import generate_stub_for_c_module
from mypy.stubutil import (
Expand Down Expand Up @@ -650,7 +650,7 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False,
if not isinstance(get_proper_type(annotated_type), AnyType):
annotation = ": {}".format(self.print_annotation(annotated_type))
if arg_.initializer:
if kind.is_named() and not any(arg.startswith('*') for arg in args):
if is_named(kind) and not any(arg.startswith('*') for arg in args):
args.append('*')
if not annotation:
typename = self.get_str_type_of_node(arg_.initializer, True, False)
Expand Down
Loading