diff --git a/flake8_trio.py b/flake8_trio.py index c55f803b..4a43d0a4 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -9,24 +9,14 @@ Pairs well with flake8-async and flake8-bugbear. """ +from __future__ import annotations + import argparse import ast import tokenize from argparse import Namespace from fnmatch import fnmatch -from typing import ( - Any, - Dict, - Iterable, - List, - NamedTuple, - Optional, - Sequence, - Set, - Tuple, - Union, - cast, -) +from typing import Any, Iterable, NamedTuple, Sequence, Union, cast from flake8.options.manager import OptionManager @@ -95,7 +85,7 @@ def __eq__(self, other: Any) -> bool: def get_matching_call( node: ast.AST, *names: str, base: str = "trio" -) -> Optional[Tuple[ast.Call, str]]: +) -> tuple[ast.Call, str] | None: if ( isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) @@ -150,7 +140,7 @@ def __repr__(self) -> str: class Flake8TrioVisitor(ast.NodeVisitor): def __init__(self, options: Namespace): super().__init__() - self._problems: List[Error] = [] + self._problems: list[Error] = [] self.suppress_errors = False self.options = options @@ -160,9 +150,7 @@ def run(cls, tree: ast.AST, options: Namespace) -> Iterable[Error]: visitor.visit(tree) yield from visitor._problems - def visit_nodes( - self, *nodes: Union[ast.AST, Iterable[ast.AST]], generic: bool = False - ): + def visit_nodes(self, *nodes: ast.AST | Iterable[ast.AST], generic: bool = False): if generic: visit = self.generic_visit else: @@ -178,10 +166,10 @@ def error(self, error: str, node: HasLineCol, *args: object): if not self.suppress_errors: self._problems.append(Error(error, node.lineno, node.col_offset, *args)) - def get_state(self, *attrs: str, copy: bool = False) -> Dict[str, Any]: + def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]: if not attrs: attrs = tuple(self.__dict__.keys()) - res: Dict[str, Any] = {} + res: dict[str, Any] = {} for attr in attrs: if attr == "_problems": continue @@ -191,7 +179,7 @@ def get_state(self, *attrs: str, copy: bool = False) -> Dict[str, Any]: res[attr] = value return res - def set_state(self, attrs: Dict[str, Any], copy: bool = False): + def set_state(self, attrs: dict[str, Any], copy: bool = False): for attr, value in attrs.items(): if copy and hasattr(value, "copy"): value = value.copy() @@ -203,7 +191,7 @@ def walk(self, *body: ast.AST) -> Iterable[ast.AST]: # ignores module and only checks the unqualified name of the decorator -def has_decorator(decorator_list: List[ast.expr], *names: str): +def has_decorator(decorator_list: list[ast.expr], *names: str): for dec in decorator_list: if (isinstance(dec, ast.Name) and dec.id in names) or ( isinstance(dec, ast.Attribute) and dec.attr in names @@ -214,8 +202,8 @@ def has_decorator(decorator_list: List[ast.expr], *names: str): # matches the fully qualified name against fnmatch pattern # used to match decorators and methods to user-supplied patterns -def fnmatch_qualified_name(name_list: List[ast.expr], *patterns: str): - def construct_name(expr: ast.expr) -> Optional[str]: +def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str): + def construct_name(expr: ast.expr) -> str | None: if isinstance(expr, ast.Call): expr = expr.func if isinstance(expr, ast.Name): @@ -257,13 +245,13 @@ def __init__(self, *args: Any, **kwargs: Any): self._safe_decorator = False # 111 - self._context_managers: List[VisitorMiscChecks.TrioContextManager] = [] - self._nursery_call: Optional[VisitorMiscChecks.NurseryCall] = None + self._context_managers: list[VisitorMiscChecks.TrioContextManager] = [] + self._nursery_call: VisitorMiscChecks.NurseryCall | None = None self.defaults = self.get_state(copy=True) # ---- 100, 101, 111, 112 ---- - def visit_With(self, node: Union[ast.With, ast.AsyncWith]): + def visit_With(self, node: ast.With | ast.AsyncWith): self.check_for_trio100(node) self.check_for_trio112(node) @@ -305,7 +293,7 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]): visit_AsyncWith = visit_With # ---- 100 ---- - def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]): + def check_for_trio100(self, node: ast.With | ast.AsyncWith): # Context manager with no `await trio.X` call within for item in (i.context_expr for i in node.items): call = get_matching_call(item, *cancel_scope_names) @@ -316,7 +304,7 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]): self.error("TRIO100", item, f"trio.{call[1]}") # ---- 101 ---- - def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]): + def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef): outer = self.get_state() self.set_state(self.defaults, copy=True) @@ -433,7 +421,7 @@ def visit_Name(self, node: ast.Name): # if with has a withitem `trio.open_nursery() as `, # and the body is only a single expression .start[_soon](), # and does not pass as a parameter to the expression - def check_for_trio112(self, node: Union[ast.With, ast.AsyncWith]): + def check_for_trio112(self, node: ast.With | ast.AsyncWith): # body is single expression if len(node.body) != 1 or not isinstance(node.body[0], ast.Expr): return @@ -462,8 +450,8 @@ def check_for_trio112(self, node: Union[ast.With, ast.AsyncWith]): # used in 102, 103 and 104 -def critical_except(node: ast.ExceptHandler) -> Optional[Statement]: - def has_exception(node: Optional[ast.expr]) -> str: +def critical_except(node: ast.ExceptHandler) -> Statement | None: + def has_exception(node: ast.expr | None) -> str: if isinstance(node, ast.Name) and node.id == "BaseException": return "BaseException" if ( @@ -496,7 +484,7 @@ class TrioScope: def __init__(self, node: ast.Call, funcname: str): self.node = node self.funcname = funcname - self.variable_name: Optional[str] = None + self.variable_name: str | None = None self.shielded: bool = False self.has_timeout: bool = True @@ -514,14 +502,14 @@ def __init__(self, node: ast.Call, funcname: str): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) - self._critical_scope: Optional[Statement] = None - self._trio_context_managers: List[Visitor102.TrioScope] = [] + self._critical_scope: Statement | None = None + self._trio_context_managers: list[Visitor102.TrioScope] = [] # if we're inside a finally, and we're not inside a scope that doesn't have # both a timeout and shield def visit_Await( self, - node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith], + node: ast.Await | ast.AsyncFor | ast.AsyncWith, visit_children: bool = True, ): if self._critical_scope is not None and not any( @@ -533,7 +521,7 @@ def visit_Await( visit_AsyncFor = visit_Await - def visit_With(self, node: Union[ast.With, ast.AsyncWith]): + def visit_With(self, node: ast.With | ast.AsyncWith): has_context_manager = False # Check for a `with trio.` @@ -564,7 +552,7 @@ def visit_AsyncWith(self, node: ast.AsyncWith): def critical_visit( self, - node: Union[ast.ExceptHandler, Iterable[ast.AST]], + node: ast.ExceptHandler | Iterable[ast.AST], block: Statement, generic: bool = False, ): @@ -612,7 +600,7 @@ def visit_Assign(self, node: ast.Assign): class Visitor103_104(Flake8TrioVisitor): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) - self.except_name: Optional[str] = "" + self.except_name: str | None = "" self.unraised: bool = False self.unraised_break: bool = False self.unraised_continue: bool = False @@ -659,7 +647,7 @@ def visit_Raise(self, node: ast.Raise): self.generic_visit(node) - def visit_Return(self, node: Union[ast.Return, ast.Yield]): + def visit_Return(self, node: ast.Return | ast.Yield): if self.unraised: # Error: must re-raise self.error("TRIO104", node) @@ -709,7 +697,7 @@ def visit_If(self, node: ast.If): # else always raises, and # always raise before break # or body always raises (before break) and is guaranteed to run at least once - def visit_For(self, node: Union[ast.For, ast.While]): + def visit_For(self, node: ast.For | ast.While): if not self.unraised: self.generic_visit(node) return @@ -829,7 +817,7 @@ class Visitor105(Flake8TrioVisitor): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) # keep a node stack so we can check whether calls are awaited - self.node_stack: List[ast.AST] = [] + self.node_stack: list[ast.AST] = [] def visit(self, node: ast.AST): self.node_stack.append(node) @@ -854,7 +842,7 @@ def visit_Call(self, node: ast.Call): self.generic_visit(node) -def empty_body(body: List[ast.stmt]) -> bool: +def empty_body(body: list[ast.stmt]) -> bool: # Does the function body consist solely of `pass`, `...`, and (doc)string literals? return all( isinstance(stmt, ast.Pass) @@ -874,9 +862,9 @@ def __init__(self, *args: Any, **kwargs: Any): self.safe_decorator = False self.async_function = False - self.uncheckpointed_statements: Set[Statement] = set() - self.uncheckpointed_before_continue: Set[Statement] = set() - self.uncheckpointed_before_break: Set[Statement] = set() + self.uncheckpointed_statements: set[Statement] = set() + self.uncheckpointed_before_continue: set[Statement] = set() + self.uncheckpointed_before_break: set[Statement] = set() self.default = self.get_state() @@ -913,7 +901,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): self.set_state(outer) # error if function exits or returns with uncheckpointed statements - def check_function_exit(self, node: Union[ast.Return, ast.AsyncFunctionDef]): + def check_function_exit(self, node: ast.Return | ast.AsyncFunctionDef): for statement in self.uncheckpointed_statements: self.error( "TRIO108" if self.has_yield else "TRIO107", @@ -937,7 +925,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef): self.set_state(outer) # checkpoint functions - def visit_Await(self, node: Union[ast.Await, ast.Raise]): + def visit_Await(self, node: ast.Await | ast.Raise): # the expression being awaited is not checkpointed # so only set checkpoint after the await node self.generic_visit(node) @@ -994,7 +982,7 @@ def visit_Try(self, node: ast.Try): try_checkpoint = self.uncheckpointed_statements # check that all except handlers checkpoint (await or most likely raise) - except_uncheckpointed_statements: Set[Statement] = set() + except_uncheckpointed_statements: set[Statement] = set() for handler in node.handlers: # enter with worst case of try @@ -1032,7 +1020,7 @@ def visit_Try(self, node: ast.Try): self.uncheckpointed_statements.difference_update(added) # valid checkpoint if both body and orelse checkpoint - def visit_If(self, node: Union[ast.If, ast.IfExp]): + def visit_If(self, node: ast.If | ast.IfExp): # visit condition self.visit_nodes(node.test) outer = self.uncheckpointed_statements.copy() @@ -1055,7 +1043,7 @@ def visit_If(self, node: Union[ast.If, ast.IfExp]): # after completing all of loop body, and after any continues. # yield in else have same requirement # state after the loop same as above, and in addition the state at any break - def visit_loop(self, node: Union[ast.While, ast.For, ast.AsyncFor]): + def visit_loop(self, node: ast.While | ast.For | ast.AsyncFor): # visit condition infinite_loop = False if isinstance(node, ast.While): @@ -1307,8 +1295,8 @@ def __call__( self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, - values: Optional[Sequence[Any]], - option_string: Optional[str] = None, + values: Sequence[Any] | None, + option_string: str | None = None, ): assert values is not None assert option_string is not None @@ -1327,7 +1315,7 @@ def __init__(self, tree: ast.AST): self._tree = tree @classmethod - def from_filename(cls, filename: str) -> "Plugin": + def from_filename(cls, filename: str) -> Plugin: with tokenize.open(filename) as f: source = f.read() return cls(ast.parse(source)) diff --git a/tests/conftest.py b/tests/conftest.py index af3a52de..0bc2c0da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -from typing import List +from __future__ import annotations import pytest @@ -15,7 +15,7 @@ def pytest_configure(config: pytest.Config): ) -def pytest_collection_modifyitems(config: pytest.Config, items: List[pytest.Item]): +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]): if config.getoption("--runfuzz"): # --runfuzz given in cli: do not skip fuzz tests return diff --git a/tests/test_changelog_and_version.py b/tests/test_changelog_and_version.py index e925a287..2c8c941d 100644 --- a/tests/test_changelog_and_version.py +++ b/tests/test_changelog_and_version.py @@ -1,9 +1,11 @@ """Tests for flake8-trio package metadata.""" +from __future__ import annotations + import os import re import unittest from pathlib import Path -from typing import Dict, Iterable, NamedTuple, Set +from typing import Iterable, NamedTuple from test_flake8_trio import trio_test_files_regex @@ -51,7 +53,7 @@ def test_version_increments_are_correct(): class test_messages_documented(unittest.TestCase): def runTest(self): - documented_errors: Dict[str, Set[str]] = {} + documented_errors: dict[str, set[str]] = {} for filename in ( "CHANGELOG.md", "README.md", @@ -71,8 +73,8 @@ def runTest(self): if re.match(trio_test_files_regex, f) } - unique_errors: Dict[str, Set[str]] = {} - missing_errors: Dict[str, Set[str]] = {} + unique_errors: dict[str, set[str]] = {} + missing_errors: dict[str, set[str]] = {} for key, codes in documented_errors.items(): unique_errors[key] = codes.copy() missing_errors[key] = set() diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 90937b8d..9399989d 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import ast import sys from argparse import Namespace -from typing import Tuple import pytest from flake8.main.application import Application @@ -18,7 +19,7 @@ def dec_list(*decorators: str) -> ast.Module: return tree -def wrap(decorators: Tuple[str, ...], decs2: str) -> bool: +def wrap(decorators: tuple[str, ...], decs2: str) -> bool: tree = dec_list(*decorators) assert isinstance(tree.body[0], ast.AsyncFunctionDef) return fnmatch_qualified_name(tree.body[0].decorator_list, decs2) diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 29489272..f529381a 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ast import copy import itertools @@ -8,7 +10,7 @@ import tokenize import unittest from pathlib import Path -from typing import DefaultDict, Iterable, List, Sequence, Tuple, Type +from typing import DefaultDict, Iterable, Sequence import pytest from flake8 import __version_info__ as flake8_version_info @@ -22,7 +24,7 @@ trio_test_files_regex = re.compile(r"trio\d\d\d(_py.*)?.py") -test_files: List[Tuple[str, str]] = sorted( +test_files: list[tuple[str, str]] = sorted( (os.path.splitext(f)[0].upper(), f) for f in os.listdir("tests") if re.match(trio_test_files_regex, f) @@ -62,7 +64,7 @@ def test_eval(test: str, path: str): assert test in Error_codes.keys(), "error code not defined in flake8_trio.py" include = [test] - expected: List[Error] = [] + expected: list[Error] = [] with open(os.path.join("tests", path), encoding="utf-8") as file: lines = file.readlines() @@ -134,7 +136,7 @@ def visit_Await(self, node: ast.Await): newnode = self.generic_visit(node.value) return newnode - def replace_async(self, node: ast.AST, target: Type[ast.AST]) -> ast.AST: + def replace_async(self, node: ast.AST, target: type[ast.AST]) -> ast.AST: node = self.generic_visit(node) newnode = target() newnode.__dict__ = node.__dict__ @@ -188,8 +190,8 @@ def assert_expected_errors(plugin: Plugin, include: Iterable[str], *expected: Er def print_first_diff(errors: Sequence[Error], expected: Sequence[Error]): - first_error_line: List[Error] = [] - first_expected_line: List[Error] = [] + first_error_line: list[Error] = [] + first_expected_line: list[Error] = [] for err, exp in zip(errors, expected): if err == exp: continue diff --git a/typings/flake8/__init__.pyi b/typings/flake8/__init__.pyi index bf8f2de6..2745afa7 100644 --- a/typings/flake8/__init__.pyi +++ b/typings/flake8/__init__.pyi @@ -5,8 +5,6 @@ This type stub file was generated by pyright. from __future__ import annotations import logging -import sys -from typing import Dict, Tuple """Top-level module for Flake8. @@ -21,8 +19,8 @@ This module """ LOG: logging.Logger = ... __version__: str = ... -__version_info__: Tuple[int, ...] = ... -_VERBOSITY_TO_LOG_LEVEL: Dict[int, int] = ... +__version_info__: tuple[int, ...] = ... +_VERBOSITY_TO_LOG_LEVEL: dict[int, int] = ... LOG_FORMAT: str = ... def configure_logging( diff --git a/typings/flake8/options/manager.pyi b/typings/flake8/options/manager.pyi index 89e2d324..540a60f3 100644 --- a/typings/flake8/options/manager.pyi +++ b/typings/flake8/options/manager.pyi @@ -3,18 +3,14 @@ This type stub file was generated by pyright. Generated for flake8 5, so OptionManager signature is incorrect for flake8 6 """ +from __future__ import annotations + import argparse from typing import ( Any, Callable, - Dict, - List, Mapping, - Optional, Sequence, - Tuple, - Type, - Union, ) from flake8.plugins.finder import Plugins @@ -22,7 +18,7 @@ from flake8.plugins.finder import Plugins """Option handling and Option management logic.""" LOG = ... _ARG = ... -_optparse_callable_map: Dict[str, Union[Type[Any], _ARG]] = ... +_optparse_callable_map: dict[str, type[Any] | _ARG] = ... class _CallbackAction(argparse.Action): """Shim for optparse-style callback actions.""" @@ -32,15 +28,15 @@ class _CallbackAction(argparse.Action): *args: Any, callback: Callable[..., Any], callback_args: Sequence[Any] = ..., - callback_kwargs: Optional[Dict[str, Any]] = ..., + callback_kwargs: dict[str, Any] | None = ..., **kwargs: Any, ) -> None: ... def __call__( self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, - values: Optional[Union[Sequence[str], str]], - option_string: Optional[str] = ..., + values: Sequence[str] | str | None, + option_string: str | None = ..., ) -> None: ... class Option: @@ -48,21 +44,21 @@ class Option: def __init__( self, - short_option_name: Union[str, _ARG] = ..., - long_option_name: Union[str, _ARG] = ..., - action: Union[str, Type[argparse.Action], _ARG] = ..., - default: Union[Any, _ARG] = ..., - type: Union[str, Callable[..., Any], _ARG] = ..., - dest: Union[str, _ARG] = ..., - nargs: Union[int, str, _ARG] = ..., - const: Union[Any, _ARG] = ..., - choices: Union[Sequence[Any], _ARG] = ..., - help: Union[str, _ARG] = ..., - metavar: Union[str, _ARG] = ..., - callback: Union[Callable[..., Any], _ARG] = ..., - callback_args: Union[Sequence[Any], _ARG] = ..., - callback_kwargs: Union[Mapping[str, Any], _ARG] = ..., - required: Union[bool, _ARG] = ..., + short_option_name: str | _ARG = ..., + long_option_name: str | _ARG = ..., + action: str | type[argparse.Action] | _ARG = ..., + default: Any | _ARG = ..., + type: str | Callable[..., Any] | _ARG = ..., + dest: str | _ARG = ..., + nargs: int | str | _ARG = ..., + const: Any | _ARG = ..., + choices: Sequence[Any] | _ARG = ..., + help: str | _ARG = ..., + metavar: str | _ARG = ..., + callback: Callable[..., Any] | _ARG = ..., + callback_args: Sequence[Any] | _ARG = ..., + callback_kwargs: Mapping[str, Any] | _ARG = ..., + required: bool | _ARG = ..., parse_from_config: bool = ..., comma_separated_list: bool = ..., normalize_paths: bool = ..., @@ -130,14 +126,14 @@ class Option: """ ... @property - def filtered_option_kwargs(self) -> Dict[str, Any]: + def filtered_option_kwargs(self) -> dict[str, Any]: """Return any actually-specified arguments.""" ... def __repr__(self) -> str: ... def normalize(self, value: Any, *normalize_args: str) -> Any: """Normalize the value based on the option configuration.""" ... - def to_argparse(self) -> Tuple[List[str], Dict[str, Any]]: + def to_argparse(self) -> tuple[list[str], dict[str, Any]]: """Convert a Flake8 Option to argparse ``add_argument`` arguments.""" ... @@ -149,7 +145,7 @@ class OptionManager: *, version: str, plugin_versions: str, - parents: List[argparse.ArgumentParser], + parents: list[argparse.ArgumentParser], ) -> None: """Initialize an instance of an OptionManager. @@ -197,8 +193,8 @@ class OptionManager: ... def parse_args( self, - args: Optional[Sequence[str]] = ..., - values: Optional[argparse.Namespace] = ..., + args: Sequence[str] | None = ..., + values: argparse.Namespace | None = ..., ) -> argparse.Namespace: """Proxy to calling the OptionParser's parse_args method.""" ...