From e7882c21064cada2bc87c63852dd04153f8e9253 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Sun, 7 Nov 2021 11:04:51 +0000 Subject: [PATCH 1/2] Speed up ArgKind methods by changing them into top-level functions Mypyc can't call enum methods as native methods. This seems to speed up compiled mypy by around 1%. --- mypy/argmap.py | 6 +++--- mypy/checker.py | 10 ++++----- mypy/checkexpr.py | 17 ++++++++-------- mypy/join.py | 4 ++-- mypy/messages.py | 11 +++++----- mypy/nodes.py | 37 +++++++++++++++++----------------- mypy/plugins/dataclasses.py | 6 +++--- mypy/plugins/functools.py | 4 ++-- mypy/plugins/singledispatch.py | 4 ++-- mypy/strconv.py | 4 ++-- mypy/stubgen.py | 4 ++-- mypy/stubtest.py | 14 ++++++------- mypy/subtypes.py | 7 ++++--- mypy/suggestions.py | 12 +++++------ mypy/typeanal.py | 4 ++-- mypy/types.py | 25 ++++++++++++----------- mypyc/ir/func_ir.py | 4 ++-- mypyc/irbuild/function.py | 8 ++++---- mypyc/irbuild/ll_builder.py | 20 +++++++++--------- 19 files changed, 104 insertions(+), 97 deletions(-) diff --git a/mypy/argmap.py b/mypy/argmap.py index cb3811161783..3e84ba61c766 100644 --- a/mypy/argmap.py +++ b/mypy/argmap.py @@ -33,7 +33,7 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind], for ai, actual_kind in enumerate(actual_kinds): if actual_kind == nodes.ARG_POS: if fi < nformals: - if not formal_kinds[fi].is_star(): + if not nodes.is_star(formal_kinds[fi]): formal_to_actual[fi].append(ai) fi += 1 elif formal_kinds[fi] == nodes.ARG_STAR: @@ -55,14 +55,14 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind], # Assume that it is an iterable (if it isn't, there will be # an error later). while fi < nformals: - if formal_kinds[fi].is_named(star=True): + if nodes.is_named(formal_kinds[fi], star=True): break else: formal_to_actual[fi].append(ai) if formal_kinds[fi] == nodes.ARG_STAR: break fi += 1 - elif actual_kind.is_named(): + elif nodes.is_named(actual_kind): assert actual_names is not None, "Internal error: named kinds without names given" name = actual_names[ai] if name in formal_names: diff --git a/mypy/checker.py b/mypy/checker.py index 36751bbc9b15..95b9684bfe29 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -36,7 +36,7 @@ Instance, NoneType, strip_type, TypeType, TypeOfAny, UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, is_named_instance, union_items, TypeQuery, LiteralType, - is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType, + is_optional_type, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType, get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType) from mypy.sametypes import is_same_type from mypy.messages import ( @@ -4512,11 +4512,11 @@ def has_no_custom_eq_checks(t: Type) -> bool: collection_type = operand_types[right_index] # We only try and narrow away 'None' for now - if not is_optional(item_type): + if not is_optional_type(item_type): continue collection_item_type = get_proper_type(builtin_item_type(collection_type)) - if collection_item_type is None or is_optional(collection_item_type): + if collection_item_type is None or is_optional_type(collection_item_type): continue if (isinstance(collection_item_type, Instance) and collection_item_type.type.fullname == 'builtins.object'): @@ -4904,7 +4904,7 @@ def refine_away_none_in_comparison(self, non_optional_types = [] for i in chain_indices: typ = operand_types[i] - if not is_optional(typ): + if not is_optional_type(typ): non_optional_types.append(typ) # Make sure we have a mixture of optional and non-optional types. @@ -4914,7 +4914,7 @@ def refine_away_none_in_comparison(self, if_map = {} for i in narrowable_operand_indices: expr_type = operand_types[i] - if not is_optional(expr_type): + if not is_optional_type(expr_type): continue if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types): if_map[operands[i]] = remove_optional(expr_type) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e850744b5c71..690bac3daf91 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -19,7 +19,7 @@ TupleType, TypedDictType, Instance, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue, is_named_instance, FunctionLike, ParamSpecType, - StarType, is_optional, remove_optional, is_generic_instance, get_proper_type, ProperType, + StarType, is_optional_type, remove_optional, is_generic_instance, get_proper_type, ProperType, get_proper_types, flatten_nested_unions ) from mypy.nodes import ( @@ -34,6 +34,7 @@ TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode, ParamSpecExpr, ArgKind, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE, + is_required, is_positional, is_star, is_named ) from mypy.literals import literal from mypy import nodes @@ -1127,7 +1128,7 @@ def infer_arg_types_in_context( for i, actuals in enumerate(formal_to_actual): for ai in actuals: - if not arg_kinds[ai].is_star(): + if not is_star(arg_kinds[ai]): res[ai] = self.accept(args[ai], callee.arg_types[i]) # Fill in the rest of the argument types. @@ -1155,7 +1156,7 @@ def infer_function_type_arguments_using_context( # valid results. erased_ctx = replace_meta_vars(ctx, ErasedType()) ret_type = callable.ret_type - if is_optional(ret_type) and is_optional(ctx): + if is_optional_type(ret_type) and is_optional_type(ctx): # If both the context and the return type are optional, unwrap the optional, # since in 99% cases this is what a user expects. In other words, we replace # Optional[T] <: Optional[int] @@ -1389,16 +1390,16 @@ def check_argument_count(self, # Check for too many or few values for formals. for i, kind in enumerate(callee.arg_kinds): - if kind.is_required() and not formal_to_actual[i] and not is_unexpected_arg_error: + if is_required(kind) and not formal_to_actual[i] and not is_unexpected_arg_error: # No actual for a mandatory formal if messages: - if kind.is_positional(): + if is_positional(kind): messages.too_few_arguments(callee, context, actual_names) else: argname = callee.arg_names[i] or "?" messages.missing_named_argument(callee, context, argname) ok = False - elif not kind.is_star() and is_duplicate_mapping( + elif not is_star(kind) and is_duplicate_mapping( formal_to_actual[i], actual_types, actual_kinds): if (self.chk.in_checked_function() or isinstance(get_proper_type(actual_types[formal_to_actual[i][0]]), @@ -1406,7 +1407,7 @@ def check_argument_count(self, if messages: messages.duplicate_argument_value(callee, i, context) ok = False - elif (kind.is_named() and formal_to_actual[i] and + elif (is_named(kind) and formal_to_actual[i] and actual_kinds[formal_to_actual[i][0]] not in [nodes.ARG_NAMED, nodes.ARG_STAR2]): # Positional argument when expecting a keyword argument. if messages: @@ -1948,7 +1949,7 @@ def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, C for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)): if new_kind == target_kind: continue - elif new_kind.is_positional() and target_kind.is_positional(): + elif is_positional(new_kind) and is_positional(target_kind): new_kinds[i] = ARG_POS else: too_complex = True diff --git a/mypy/join.py b/mypy/join.py index 291a934e5943..d6e56b2a6607 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -14,7 +14,7 @@ is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype, is_protocol_implementation, find_member ) -from mypy.nodes import INVARIANT, COVARIANT, CONTRAVARIANT +from mypy.nodes import INVARIANT, COVARIANT, CONTRAVARIANT, is_named import mypy.typeops from mypy import state @@ -532,7 +532,7 @@ def combine_arg_names(t: CallableType, s: CallableType) -> List[Optional[str]]: for i in range(num_args): t_name = t.arg_names[i] s_name = s.arg_names[i] - if t_name == s_name or t.arg_kinds[i].is_named() or s.arg_kinds[i].is_named(): + if t_name == s_name or is_named(t.arg_kinds[i]) or is_named(s.arg_kinds[i]): new_names.append(t_name) else: new_names.append(None) diff --git a/mypy/messages.py b/mypy/messages.py index e3b12f49d980..a65a483cdbb6 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -31,7 +31,8 @@ TypeInfo, Context, MypyFile, FuncDef, reverse_builtin_aliases, ArgKind, ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode, - CallExpr, IndexExpr, StrExpr, SymbolTable, TempNode, SYMBOL_FUNCBASE_TYPES + CallExpr, IndexExpr, StrExpr, SymbolTable, TempNode, SYMBOL_FUNCBASE_TYPES, is_positional, + is_named, is_star, is_optional ) from mypy.operators import op_methods, op_methods_to_symbols from mypy.subtypes import ( @@ -1764,12 +1765,12 @@ def format(typ: Type) -> str: for arg_name, arg_type, arg_kind in zip( func.arg_names, func.arg_types, func.arg_kinds): if (arg_kind == ARG_POS and arg_name is None - or verbosity == 0 and arg_kind.is_positional()): + or verbosity == 0 and is_positional(arg_kind)): arg_strings.append(format(arg_type)) else: constructor = ARG_CONSTRUCTOR_NAMES[arg_kind] - if arg_kind.is_star() or arg_name is None: + if is_star(arg_kind) or arg_name is None: arg_strings.append("{}({})".format( constructor, format(arg_type))) @@ -1912,7 +1913,7 @@ def [T <: int] f(self, x: int, y: T) -> None for i in range(len(tp.arg_types)): if s: s += ', ' - if tp.arg_kinds[i].is_named() and not asterisk: + if is_named(tp.arg_kinds[i]) and not asterisk: s += '*, ' asterisk = True if tp.arg_kinds[i] == ARG_STAR: @@ -1924,7 +1925,7 @@ def [T <: int] f(self, x: int, y: T) -> None if name: s += name + ': ' s += format_type_bare(tp.arg_types[i]) - if tp.arg_kinds[i].is_optional(): + if is_optional(tp.arg_kinds[i]): s += ' = ...' # If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list diff --git a/mypy/nodes.py b/mypy/nodes.py index 1501c20514c0..7272550fb9e8 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1641,28 +1641,29 @@ class ArgKind(Enum): # In an argument list, keyword-only and also optional ARG_NAMED_OPT = 5 - def is_positional(self, star: bool = False) -> bool: - return ( - self == ARG_POS - or self == ARG_OPT - or (star and self == ARG_STAR) - ) - def is_named(self, star: bool = False) -> bool: - return ( - self == ARG_NAMED - or self == ARG_NAMED_OPT - or (star and self == ARG_STAR2) - ) +def is_positional(self: ArgKind, star: bool = False) -> bool: + return ( + self == ARG_POS + or self == ARG_OPT + or (star and self == ARG_STAR) + ) + +def is_named(self: ArgKind, star: bool = False) -> bool: + return ( + self == ARG_NAMED + or self == ARG_NAMED_OPT + or (star and self == ARG_STAR2) + ) - def is_required(self) -> bool: - return self == ARG_POS or self == ARG_NAMED +def is_required(self: ArgKind) -> bool: + return self == ARG_POS or self == ARG_NAMED - def is_optional(self) -> bool: - return self == ARG_OPT or self == ARG_NAMED_OPT +def is_optional(self: ArgKind) -> bool: + return self == ARG_OPT or self == ARG_NAMED_OPT - def is_star(self) -> bool: - return self == ARG_STAR or self == ARG_STAR2 +def is_star(self: ArgKind) -> bool: + return self == ARG_STAR or self == ARG_STAR2 ARG_POS: Final = ArgKind.ARG_POS diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 0bed61e3eeb1..fdea9636eb1f 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -6,7 +6,7 @@ from mypy.nodes import ( ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr, Context, Expression, JsonDict, NameExpr, RefExpr, - SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, PlaceholderNode + SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, PlaceholderNode, is_named ) from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface from mypy.plugins.common import ( @@ -495,8 +495,8 @@ def _collect_field_args(expr: Expression, # field() only takes keyword arguments. args = {} for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds): - if not kind.is_named(): - if kind.is_named(star=True): + if not is_named(kind): + if is_named(kind, star=True): # This means that `field` is used with `**` unpacking, # the best we can do for now is not to fail. # TODO: we can infer what's inside `**` and try to collect it. diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index e52d478927e8..6b32e3920ff9 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -3,7 +3,7 @@ from typing_extensions import Final import mypy.plugin -from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var +from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var, is_positional from mypy.plugins.common import add_method_to_class from mypy.types import AnyType, CallableType, get_proper_type, Type, TypeOfAny, UnboundType @@ -66,7 +66,7 @@ def _find_other_type(method: _MethodInfo) -> Type: cur_pos_arg = 0 other_arg = None for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types): - if arg_kind.is_positional(): + if is_positional(arg_kind): if cur_pos_arg == first_arg_pos: other_arg = arg_type break diff --git a/mypy/plugins/singledispatch.py b/mypy/plugins/singledispatch.py index 104faa38d1ce..fc31846e499d 100644 --- a/mypy/plugins/singledispatch.py +++ b/mypy/plugins/singledispatch.py @@ -1,7 +1,7 @@ from mypy.messages import format_type from mypy.plugins.common import add_method_to_class from mypy.nodes import ( - ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, Context + ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, Context, is_positional ) from mypy.subtypes import is_subtype from mypy.types import ( @@ -100,7 +100,7 @@ def create_singledispatch_function_callback(ctx: FunctionContext) -> Type: ) return ctx.default_return_type - elif not func_type.arg_kinds[0].is_positional(star=True): + elif not is_positional(func_type.arg_kinds[0], star=True): fail( ctx, 'First argument to singledispatch function must be a positional argument', diff --git a/mypy/strconv.py b/mypy/strconv.py index c63063af0776..92e5ec71fc9c 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -62,9 +62,9 @@ def func_helper(self, o: 'mypy.nodes.FuncItem') -> List[object]: extra: List[Tuple[str, List[mypy.nodes.Var]]] = [] for arg in o.arguments: kind: mypy.nodes.ArgKind = arg.kind - if kind.is_required(): + if mypy.nodes.is_required(kind): args.append(arg.variable) - elif kind.is_optional(): + elif mypy.nodes.is_optional(kind): assert arg.initializer is not None args.append(('default', [arg.variable, arg.initializer])) elif kind == mypy.nodes.ARG_STAR: diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 5518691b35fb..3b37aaa7cb1c 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -72,7 +72,7 @@ TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr, ClassDef, MypyFile, Decorator, AssignmentStmt, TypeInfo, IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, Block, - Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, + Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, is_named ) from mypy.stubgenc import generate_stub_for_c_module from mypy.stubutil import ( @@ -650,7 +650,7 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False, if not isinstance(get_proper_type(annotated_type), AnyType): annotation = ": {}".format(self.print_annotation(annotated_type)) if arg_.initializer: - if kind.is_named() and not any(arg.startswith('*') for arg in args): + if is_named(kind) and not any(arg.startswith('*') for arg in args): args.append('*') if not annotation: typename = self.get_str_type_of_node(arg_.initializer, True, False) diff --git a/mypy/stubtest.py b/mypy/stubtest.py index 138f126c9d1a..b7c00fb25507 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -341,7 +341,7 @@ def _verify_arg_default_value( ) -> Iterator[str]: """Checks whether argument default values are compatible.""" if runtime_arg.default != inspect.Parameter.empty: - if stub_arg.kind.is_required(): + if nodes.is_required(stub_arg.kind): yield ( 'runtime argument "{}" has a default value but stub argument does not'.format( runtime_arg.name @@ -370,7 +370,7 @@ def _verify_arg_default_value( ) ) else: - if stub_arg.kind.is_optional(): + if nodes.is_optional(stub_arg.kind): yield ( 'stub argument "{}" has a default value but runtime argument does not'.format( stub_arg.variable.name @@ -413,7 +413,7 @@ def has_default(arg: Any) -> bool: if isinstance(arg, inspect.Parameter): return arg.default != inspect.Parameter.empty if isinstance(arg, nodes.Argument): - return arg.kind.is_optional() + return nodes.is_optional(arg.kind) raise AssertionError def get_desc(arg: Any) -> str: @@ -440,9 +440,9 @@ def from_funcitem(stub: nodes.FuncItem) -> "Signature[nodes.Argument]": stub_sig: Signature[nodes.Argument] = Signature() stub_args = maybe_strip_cls(stub.name, stub.arguments) for stub_arg in stub_args: - if stub_arg.kind.is_positional(): + if nodes.is_positional(stub_arg.kind): stub_sig.pos.append(stub_arg) - elif stub_arg.kind.is_named(): + elif nodes.is_named(stub_arg.kind): stub_sig.kwonly[stub_arg.variable.name] = stub_arg elif stub_arg.kind == nodes.ARG_STAR: stub_sig.varpos = stub_arg @@ -538,9 +538,9 @@ def get_kind(arg_name: str) -> nodes.ArgKind: initializer=None, kind=get_kind(arg_name), ) - if arg.kind.is_positional(): + if nodes.is_positional(arg.kind): sig.pos.append(arg) - elif arg.kind.is_named(): + elif nodes.is_named(arg.kind): sig.kwonly[arg.variable.name] = arg elif arg.kind == nodes.ARG_STAR: sig.varpos = arg diff --git a/mypy/subtypes.py b/mypy/subtypes.py index f9d27b7a1656..351b174280da 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -18,6 +18,7 @@ # import mypy.solve from mypy.nodes import ( FuncBase, Var, Decorator, OverloadedFuncDef, TypeInfo, CONTRAVARIANT, COVARIANT, + is_star, is_positional, is_optional ) from mypy.maptype import map_instance_to_supertype @@ -966,8 +967,8 @@ def _incompatible(left_arg: Optional[FormalArgument], i = right_star.pos assert i is not None - while i < len(left.arg_kinds) and left.arg_kinds[i].is_positional(): - if allow_partial_overlap and left.arg_kinds[i].is_optional(): + while i < len(left.arg_kinds) and is_positional(left.arg_kinds[i]): + if allow_partial_overlap and is_optional(left.arg_kinds[i]): break left_by_position = left.argument_by_position(i) @@ -986,7 +987,7 @@ def _incompatible(left_arg: Optional[FormalArgument], right_names = {name for name in right.arg_names if name is not None} left_only_names = set() for name, kind in zip(left.arg_names, left.arg_kinds): - if name is None or kind.is_star() or name in right_names: + if name is None or is_star(kind) or name in right_names: continue left_only_names.add(name) diff --git a/mypy/suggestions.py b/mypy/suggestions.py index 87b54814c637..3f13db0e4a0a 100644 --- a/mypy/suggestions.py +++ b/mypy/suggestions.py @@ -32,7 +32,7 @@ Type, AnyType, TypeOfAny, CallableType, UnionType, NoneType, Instance, TupleType, TypeVarType, FunctionLike, UninhabitedType, TypeStrVisitor, TypeTranslator, - is_optional, remove_optional, ProperType, get_proper_type, + is_optional_type, remove_optional, ProperType, get_proper_type, TypedDictType, TypeAliasType ) from mypy.build import State, Graph @@ -40,7 +40,7 @@ ArgKind, ARG_STAR, ARG_STAR2, FuncDef, MypyFile, SymbolTable, Decorator, RefExpr, SymbolNode, TypeInfo, Expression, ReturnStmt, CallExpr, - reverse_builtin_aliases, + reverse_builtin_aliases, is_named, is_star ) from mypy.server.update import FineGrainedBuildManager from mypy.util import split_target @@ -479,7 +479,7 @@ def format_args(self, arg = '*' + arg elif kind == ARG_STAR2: arg = '**' + arg - elif kind.is_named(): + elif is_named(kind): if name: arg = "%s=%s" % (name, arg) args.append(arg) @@ -712,7 +712,7 @@ def score_type(self, t: Type, arg_pos: bool) -> int: return 20 if any(has_any_type(x) for x in t.items): return 15 - if not is_optional(t): + if not is_optional_type(t): return 10 if isinstance(t, CallableType) and (has_any_type(t) or is_tricky_callable(t)): return 10 @@ -763,7 +763,7 @@ def any_score_callable(t: CallableType, is_method: bool, ignore_return: bool) -> def is_tricky_callable(t: CallableType) -> bool: """Is t a callable that we need to put a ... in for syntax reasons?""" - return t.is_ellipsis_args or any(k.is_star() or k.is_named() for k in t.arg_kinds) + return t.is_ellipsis_args or any(is_star(k) or is_named(k) for k in t.arg_kinds) class TypeFormatter(TypeStrVisitor): @@ -829,7 +829,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> str: return t.fallback.accept(self) def visit_union_type(self, t: UnionType) -> str: - if len(t.items) == 2 and is_optional(t): + if len(t.items) == 2 and is_optional_type(t): return "Optional[{}]".format(remove_optional(t).accept(self)) else: return super().visit_union_type(t) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index d400c7e1ca69..06843349d97f 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -24,7 +24,7 @@ TypeInfo, Context, SymbolTableNode, Var, Expression, get_nongen_builtins, check_arg_names, check_arg_kinds, ArgKind, ARG_POS, ARG_NAMED, ARG_OPT, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, TypeVarExpr, TypeVarLikeExpr, ParamSpecExpr, - TypeAlias, PlaceholderNode, SYMBOL_FUNCBASE_TYPES, Decorator, MypyFile + TypeAlias, PlaceholderNode, SYMBOL_FUNCBASE_TYPES, Decorator, MypyFile, is_star ) from mypy.typetraverser import TypeTraverserVisitor from mypy.tvar_scope import TypeVarLikeScope @@ -782,7 +782,7 @@ def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type], assert found.fullname is not None kind = ARG_KINDS_BY_CONSTRUCTOR[found.fullname] kinds.append(kind) - if arg.name is not None and kind.is_star(): + if arg.name is not None and is_star(kind): self.fail("{} arguments should not have names".format( arg.constructor), arg) return None diff --git a/mypy/types.py b/mypy/types.py index f7d2b4dd1931..9ed64b14c636 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -15,7 +15,8 @@ from mypy import state from mypy.nodes import ( INVARIANT, SymbolNode, FuncDef, - ArgKind, ARG_POS, ARG_STAR, ARG_STAR2, + ArgKind, ARG_POS, ARG_STAR, ARG_STAR2, is_positional, is_named, is_star, is_required, + is_optional ) from mypy.util import IdMapper from mypy.bogus_type import Bogus @@ -1166,7 +1167,7 @@ def max_possible_positional_args(self) -> int: This takes into account *arg and **kwargs but excludes keyword-only args.""" if self.is_var_arg or self.is_kw_arg: return sys.maxsize - return sum([kind.is_positional() for kind in self.arg_kinds]) + return sum([is_positional(kind) for kind in self.arg_kinds]) def formal_arguments(self, include_star_args: bool = False) -> List[FormalArgument]: """Yields the formal arguments corresponding to this callable, ignoring *arg and **kwargs. @@ -1180,12 +1181,12 @@ def formal_arguments(self, include_star_args: bool = False) -> List[FormalArgume done_with_positional = False for i in range(len(self.arg_types)): kind = self.arg_kinds[i] - if kind.is_named() or kind.is_star(): + if is_named(kind) or is_star(kind): done_with_positional = True - if not include_star_args and kind.is_star(): + if not include_star_args and is_star(kind): continue - required = kind.is_required() + required = is_required(kind) pos = None if done_with_positional else i arg = FormalArgument( self.arg_names[i], @@ -1203,13 +1204,13 @@ def argument_by_name(self, name: Optional[str]) -> Optional[FormalArgument]: for i, (arg_name, kind, typ) in enumerate( zip(self.arg_names, self.arg_kinds, self.arg_types)): # No more positional arguments after these. - if kind.is_named() or kind.is_star(): + if is_named(kind) or is_star(kind): seen_star = True - if kind.is_star(): + if is_star(kind): continue if arg_name == name: position = None if seen_star else i - return FormalArgument(name, position, typ, kind.is_required()) + return FormalArgument(name, position, typ, is_required(kind)) return self.try_synthesizing_arg_from_kwarg(name) def argument_by_position(self, position: Optional[int]) -> Optional[FormalArgument]: @@ -1222,7 +1223,7 @@ def argument_by_position(self, position: Optional[int]) -> Optional[FormalArgume self.arg_kinds[position], self.arg_types[position], ) - if kind.is_positional(): + if is_positional(kind): return FormalArgument(name, position, typ, kind == ARG_POS) else: return self.try_synthesizing_arg_from_vararg(position) @@ -2126,7 +2127,7 @@ def visit_callable_type(self, t: CallableType) -> str: for i in range(len(t.arg_types)): if s != '': s += ', ' - if t.arg_kinds[i].is_named() and not bare_asterisk: + if is_named(t.arg_kinds[i]) and not bare_asterisk: s += '*, ' bare_asterisk = True if t.arg_kinds[i] == ARG_STAR: @@ -2137,7 +2138,7 @@ def visit_callable_type(self, t: CallableType) -> str: if name: s += name + ': ' s += t.arg_types[i].accept(self) - if t.arg_kinds[i].is_optional(): + if is_optional(t.arg_kinds[i]): s += ' =' s = '({})'.format(s) @@ -2388,7 +2389,7 @@ def is_generic_instance(tp: Type) -> bool: return isinstance(tp, Instance) and bool(tp.args) -def is_optional(t: Type) -> bool: +def is_optional_type(t: Type) -> bool: t = get_proper_type(t) return isinstance(t, UnionType) and any(isinstance(get_proper_type(e), NoneType) for e in t.items) diff --git a/mypyc/ir/func_ir.py b/mypyc/ir/func_ir.py index 1426b0ecdf0f..30799074866b 100644 --- a/mypyc/ir/func_ir.py +++ b/mypyc/ir/func_ir.py @@ -3,7 +3,7 @@ from typing import List, Optional, Sequence from typing_extensions import Final -from mypy.nodes import FuncDef, Block, ArgKind, ARG_POS +from mypy.nodes import FuncDef, Block, ArgKind, ARG_POS, is_optional from mypyc.common import JsonDict, get_id_from_name, short_id_from_name from mypyc.ir.ops import ( @@ -28,7 +28,7 @@ def __init__( @property def optional(self) -> bool: - return self.kind.is_optional() + return is_optional(self.kind) def __repr__(self) -> str: return 'RuntimeArg(name=%s, type=%s, optional=%r, pos_only=%r)' % ( diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index bdd4ed992f2f..3d481f0cad16 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -16,7 +16,7 @@ from mypy.nodes import ( ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr, - FuncItem, LambdaExpr, SymbolNode, ArgKind, TypeInfo + FuncItem, LambdaExpr, SymbolNode, ArgKind, TypeInfo, is_named, is_optional, is_star ) from mypy.types import CallableType, get_proper_type @@ -667,7 +667,7 @@ def get_args(builder: IRBuilder, rt_args: Sequence[RuntimeArg], line: int) -> Ar args = [builder.read(builder.add_local_reg(var, type, is_arg=True), line) for var, type in fake_vars] arg_names = [arg.name - if arg.kind.is_named() or (arg.kind.is_optional() and not arg.pos_only) else None + if is_named(arg.kind) or (is_optional(arg.kind) and not arg.pos_only) else None for arg in rt_args] arg_kinds = [arg.kind for arg in rt_args] return ArgInfo(args, arg_names, arg_kinds) @@ -715,8 +715,8 @@ def f(builder: IRBuilder, x: object) -> int: ... # We can do a passthrough *args/**kwargs with a native call, but if the # args need to get distributed out to arguments, we just let python handle it if ( - any(kind.is_star() for kind in arg_kinds) - and any(not arg.kind.is_star() for arg in target.decl.sig.args) + any(is_star(kind) for kind in arg_kinds) + and any(not is_star(arg.kind) for arg in target.decl.sig.args) ): do_pycall = True diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 27419fcc7385..ef2468aab9b9 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -14,7 +14,9 @@ from typing_extensions import Final -from mypy.nodes import ArgKind, ARG_POS, ARG_STAR, ARG_STAR2 +from mypy.nodes import ( + ArgKind, ARG_POS, ARG_STAR, ARG_STAR2, is_optional, is_positional, is_named, is_star +) from mypy.operators import op_methods from mypy.types import AnyType, TypeOfAny from mypy.checkexpr import map_actuals_to_formals @@ -393,9 +395,9 @@ def _construct_varargs(self, line=line ) else: - nullable = kind.is_optional() - maybe_pos = kind.is_positional() and has_star - maybe_named = kind.is_named() or (kind.is_optional() and name and has_star2) + nullable = is_optional(kind) + maybe_pos = is_positional(kind) and has_star + maybe_named = is_named(kind) or (is_optional(kind) and name and has_star2) # If the argument is nullable, we need to create the # relevant args/kwargs objects so that we can @@ -530,7 +532,7 @@ def _py_vector_call(self, API should be used instead. """ # We can do this if all args are positional or named (no *args or **kwargs, not optional). - if arg_kinds is None or all(not kind.is_star() and not kind.is_optional() + if arg_kinds is None or all(not is_star(kind) and not is_optional(kind) for kind in arg_kinds): if arg_values: # Create a C array containing all arguments as boxed values. @@ -602,7 +604,7 @@ def _py_vector_method_call(self, Return the return value if successful. Return None if a non-vectorcall API should be used instead. """ - if arg_kinds is None or all(not kind.is_star() and not kind.is_optional() + if arg_kinds is None or all(not is_star(kind) and not is_optional(kind) for kind in arg_kinds): method_name_reg = self.load_str(method_name) array = Register(RArray(object_rprimitive, len(arg_values) + 1)) @@ -667,7 +669,7 @@ def native_args_to_positional(self, has_star = has_star2 = False star_arg_entries = [] for lst, arg in zip(formal_to_actual, sig.args): - if arg.kind.is_star(): + if is_star(arg.kind): star_arg_entries.extend([(args[i], arg_kinds[i], arg_names[i]) for i in lst]) has_star = has_star or arg.kind == ARG_STAR has_star2 = has_star2 or arg.kind == ARG_STAR2 @@ -692,7 +694,7 @@ def native_args_to_positional(self, else: base_arg = args[lst[0]] - if arg_kinds[lst[0]].is_optional(): + if is_optional(arg_kinds[lst[0]]): output_arg = self.coerce_nullable(base_arg, arg.type, line) else: output_arg = self.coerce(base_arg, arg.type, line) @@ -711,7 +713,7 @@ def gen_method_call(self, arg_names: Optional[List[Optional[str]]] = None) -> Value: """Generate either a native or Python method call.""" # If we have *args, then fallback to Python method call. - if arg_kinds is not None and any(kind.is_star() for kind in arg_kinds): + if arg_kinds is not None and any(is_star(kind) for kind in arg_kinds): return self.py_method_call(base, name, arg_values, base.line, arg_kinds, arg_names) # If the base type is one of ours, do a MethodCall From 395dd62f174664ba1ac5e49aa044f9dd656a73ef Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Sat, 13 Nov 2021 12:55:12 +0000 Subject: [PATCH 2/2] Clean up + fix another call site --- mypy/nodes.py | 34 +++++++++++++++++++--------------- mypy/semanal.py | 5 ++--- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index 7272550fb9e8..0a059d665d62 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1625,9 +1625,9 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_member_expr(self) -# Kinds of arguments @unique class ArgKind(Enum): + """Kinds of arguments""" # Positional argument ARG_POS = 0 # Positional, optional argument (functions only, not calls) @@ -1642,28 +1642,32 @@ class ArgKind(Enum): ARG_NAMED_OPT = 5 -def is_positional(self: ArgKind, star: bool = False) -> bool: +def is_positional(kind: ArgKind, star: bool = False) -> bool: return ( - self == ARG_POS - or self == ARG_OPT - or (star and self == ARG_STAR) + kind == ARG_POS + or kind == ARG_OPT + or (star and kind == ARG_STAR) ) -def is_named(self: ArgKind, star: bool = False) -> bool: + +def is_named(kind: ArgKind, star: bool = False) -> bool: return ( - self == ARG_NAMED - or self == ARG_NAMED_OPT - or (star and self == ARG_STAR2) + kind == ARG_NAMED + or kind == ARG_NAMED_OPT + or (star and kind == ARG_STAR2) ) -def is_required(self: ArgKind) -> bool: - return self == ARG_POS or self == ARG_NAMED -def is_optional(self: ArgKind) -> bool: - return self == ARG_OPT or self == ARG_NAMED_OPT +def is_required(kind: ArgKind) -> bool: + return kind == ARG_POS or kind == ARG_NAMED + + +def is_optional(kind: ArgKind) -> bool: + return kind == ARG_OPT or kind == ARG_NAMED_OPT + -def is_star(self: ArgKind) -> bool: - return self == ARG_STAR or self == ARG_STAR2 +def is_star(kind: ArgKind) -> bool: + return kind == ARG_STAR or kind == ARG_STAR2 ARG_POS: Final = ArgKind.ARG_POS diff --git a/mypy/semanal.py b/mypy/semanal.py index f58f533a268f..8b65bfa56d80 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -76,8 +76,7 @@ get_nongen_builtins, get_member_expr_fullname, REVEAL_TYPE, REVEAL_LOCALS, is_final_node, TypedDictExpr, type_aliases_source_versions, EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr, - ParamSpecExpr, EllipsisExpr, - FuncBase, implicit_module_attrs, + ParamSpecExpr, EllipsisExpr, FuncBase, implicit_module_attrs, is_named ) from mypy.tvar_scope import TypeVarLikeScope from mypy.typevars import fill_typevars @@ -3142,7 +3141,7 @@ def process_typevar_parameters(self, args: List[Expression], contravariant = False upper_bound: Type = self.object_type() for param_value, param_name, param_kind in zip(args, names, kinds): - if not param_kind.is_named(): + if not is_named(param_kind): self.fail("Unexpected argument to TypeVar()", context) return None if param_name == 'covariant':