From 1b7e70f89ed9703bacd32ee6f2c13d8f9a886867 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 7 Dec 2022 11:29:45 +0100
Subject: [PATCH] import annotations from __future__ and update annotations
---
flake8_trio.py | 96 +++++++++++++----------------
tests/conftest.py | 4 +-
tests/test_changelog_and_version.py | 10 +--
tests/test_decorator.py | 5 +-
tests/test_flake8_trio.py | 14 +++--
typings/flake8/__init__.pyi | 6 +-
typings/flake8/options/manager.pyi | 56 ++++++++---------
7 files changed, 89 insertions(+), 102 deletions(-)
diff --git a/flake8_trio.py b/flake8_trio.py
index c55f803b..4a43d0a4 100644
--- a/flake8_trio.py
+++ b/flake8_trio.py
@@ -9,24 +9,14 @@
Pairs well with flake8-async and flake8-bugbear.
"""
+from __future__ import annotations
+
import argparse
import ast
import tokenize
from argparse import Namespace
from fnmatch import fnmatch
-from typing import (
- Any,
- Dict,
- Iterable,
- List,
- NamedTuple,
- Optional,
- Sequence,
- Set,
- Tuple,
- Union,
- cast,
-)
+from typing import Any, Iterable, NamedTuple, Sequence, Union, cast
from flake8.options.manager import OptionManager
@@ -95,7 +85,7 @@ def __eq__(self, other: Any) -> bool:
def get_matching_call(
node: ast.AST, *names: str, base: str = "trio"
-) -> Optional[Tuple[ast.Call, str]]:
+) -> tuple[ast.Call, str] | None:
if (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
@@ -150,7 +140,7 @@ def __repr__(self) -> str:
class Flake8TrioVisitor(ast.NodeVisitor):
def __init__(self, options: Namespace):
super().__init__()
- self._problems: List[Error] = []
+ self._problems: list[Error] = []
self.suppress_errors = False
self.options = options
@@ -160,9 +150,7 @@ def run(cls, tree: ast.AST, options: Namespace) -> Iterable[Error]:
visitor.visit(tree)
yield from visitor._problems
- def visit_nodes(
- self, *nodes: Union[ast.AST, Iterable[ast.AST]], generic: bool = False
- ):
+ def visit_nodes(self, *nodes: ast.AST | Iterable[ast.AST], generic: bool = False):
if generic:
visit = self.generic_visit
else:
@@ -178,10 +166,10 @@ def error(self, error: str, node: HasLineCol, *args: object):
if not self.suppress_errors:
self._problems.append(Error(error, node.lineno, node.col_offset, *args))
- def get_state(self, *attrs: str, copy: bool = False) -> Dict[str, Any]:
+ def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]:
if not attrs:
attrs = tuple(self.__dict__.keys())
- res: Dict[str, Any] = {}
+ res: dict[str, Any] = {}
for attr in attrs:
if attr == "_problems":
continue
@@ -191,7 +179,7 @@ def get_state(self, *attrs: str, copy: bool = False) -> Dict[str, Any]:
res[attr] = value
return res
- def set_state(self, attrs: Dict[str, Any], copy: bool = False):
+ def set_state(self, attrs: dict[str, Any], copy: bool = False):
for attr, value in attrs.items():
if copy and hasattr(value, "copy"):
value = value.copy()
@@ -203,7 +191,7 @@ def walk(self, *body: ast.AST) -> Iterable[ast.AST]:
# ignores module and only checks the unqualified name of the decorator
-def has_decorator(decorator_list: List[ast.expr], *names: str):
+def has_decorator(decorator_list: list[ast.expr], *names: str):
for dec in decorator_list:
if (isinstance(dec, ast.Name) and dec.id in names) or (
isinstance(dec, ast.Attribute) and dec.attr in names
@@ -214,8 +202,8 @@ def has_decorator(decorator_list: List[ast.expr], *names: str):
# matches the fully qualified name against fnmatch pattern
# used to match decorators and methods to user-supplied patterns
-def fnmatch_qualified_name(name_list: List[ast.expr], *patterns: str):
- def construct_name(expr: ast.expr) -> Optional[str]:
+def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str):
+ def construct_name(expr: ast.expr) -> str | None:
if isinstance(expr, ast.Call):
expr = expr.func
if isinstance(expr, ast.Name):
@@ -257,13 +245,13 @@ def __init__(self, *args: Any, **kwargs: Any):
self._safe_decorator = False
# 111
- self._context_managers: List[VisitorMiscChecks.TrioContextManager] = []
- self._nursery_call: Optional[VisitorMiscChecks.NurseryCall] = None
+ self._context_managers: list[VisitorMiscChecks.TrioContextManager] = []
+ self._nursery_call: VisitorMiscChecks.NurseryCall | None = None
self.defaults = self.get_state(copy=True)
# ---- 100, 101, 111, 112 ----
- def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
+ def visit_With(self, node: ast.With | ast.AsyncWith):
self.check_for_trio100(node)
self.check_for_trio112(node)
@@ -305,7 +293,7 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
visit_AsyncWith = visit_With
# ---- 100 ----
- def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
+ def check_for_trio100(self, node: ast.With | ast.AsyncWith):
# Context manager with no `await trio.X` call within
for item in (i.context_expr for i in node.items):
call = get_matching_call(item, *cancel_scope_names)
@@ -316,7 +304,7 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
self.error("TRIO100", item, f"trio.{call[1]}")
# ---- 101 ----
- def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
+ def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
outer = self.get_state()
self.set_state(self.defaults, copy=True)
@@ -433,7 +421,7 @@ def visit_Name(self, node: ast.Name):
# if with has a withitem `trio.open_nursery() as `,
# and the body is only a single expression .start[_soon](),
# and does not pass as a parameter to the expression
- def check_for_trio112(self, node: Union[ast.With, ast.AsyncWith]):
+ def check_for_trio112(self, node: ast.With | ast.AsyncWith):
# body is single expression
if len(node.body) != 1 or not isinstance(node.body[0], ast.Expr):
return
@@ -462,8 +450,8 @@ def check_for_trio112(self, node: Union[ast.With, ast.AsyncWith]):
# used in 102, 103 and 104
-def critical_except(node: ast.ExceptHandler) -> Optional[Statement]:
- def has_exception(node: Optional[ast.expr]) -> str:
+def critical_except(node: ast.ExceptHandler) -> Statement | None:
+ def has_exception(node: ast.expr | None) -> str:
if isinstance(node, ast.Name) and node.id == "BaseException":
return "BaseException"
if (
@@ -496,7 +484,7 @@ class TrioScope:
def __init__(self, node: ast.Call, funcname: str):
self.node = node
self.funcname = funcname
- self.variable_name: Optional[str] = None
+ self.variable_name: str | None = None
self.shielded: bool = False
self.has_timeout: bool = True
@@ -514,14 +502,14 @@ def __init__(self, node: ast.Call, funcname: str):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
- self._critical_scope: Optional[Statement] = None
- self._trio_context_managers: List[Visitor102.TrioScope] = []
+ self._critical_scope: Statement | None = None
+ self._trio_context_managers: list[Visitor102.TrioScope] = []
# if we're inside a finally, and we're not inside a scope that doesn't have
# both a timeout and shield
def visit_Await(
self,
- node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith],
+ node: ast.Await | ast.AsyncFor | ast.AsyncWith,
visit_children: bool = True,
):
if self._critical_scope is not None and not any(
@@ -533,7 +521,7 @@ def visit_Await(
visit_AsyncFor = visit_Await
- def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
+ def visit_With(self, node: ast.With | ast.AsyncWith):
has_context_manager = False
# Check for a `with trio.`
@@ -564,7 +552,7 @@ def visit_AsyncWith(self, node: ast.AsyncWith):
def critical_visit(
self,
- node: Union[ast.ExceptHandler, Iterable[ast.AST]],
+ node: ast.ExceptHandler | Iterable[ast.AST],
block: Statement,
generic: bool = False,
):
@@ -612,7 +600,7 @@ def visit_Assign(self, node: ast.Assign):
class Visitor103_104(Flake8TrioVisitor):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
- self.except_name: Optional[str] = ""
+ self.except_name: str | None = ""
self.unraised: bool = False
self.unraised_break: bool = False
self.unraised_continue: bool = False
@@ -659,7 +647,7 @@ def visit_Raise(self, node: ast.Raise):
self.generic_visit(node)
- def visit_Return(self, node: Union[ast.Return, ast.Yield]):
+ def visit_Return(self, node: ast.Return | ast.Yield):
if self.unraised:
# Error: must re-raise
self.error("TRIO104", node)
@@ -709,7 +697,7 @@ def visit_If(self, node: ast.If):
# else always raises, and
# always raise before break
# or body always raises (before break) and is guaranteed to run at least once
- def visit_For(self, node: Union[ast.For, ast.While]):
+ def visit_For(self, node: ast.For | ast.While):
if not self.unraised:
self.generic_visit(node)
return
@@ -829,7 +817,7 @@ class Visitor105(Flake8TrioVisitor):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
# keep a node stack so we can check whether calls are awaited
- self.node_stack: List[ast.AST] = []
+ self.node_stack: list[ast.AST] = []
def visit(self, node: ast.AST):
self.node_stack.append(node)
@@ -854,7 +842,7 @@ def visit_Call(self, node: ast.Call):
self.generic_visit(node)
-def empty_body(body: List[ast.stmt]) -> bool:
+def empty_body(body: list[ast.stmt]) -> bool:
# Does the function body consist solely of `pass`, `...`, and (doc)string literals?
return all(
isinstance(stmt, ast.Pass)
@@ -874,9 +862,9 @@ def __init__(self, *args: Any, **kwargs: Any):
self.safe_decorator = False
self.async_function = False
- self.uncheckpointed_statements: Set[Statement] = set()
- self.uncheckpointed_before_continue: Set[Statement] = set()
- self.uncheckpointed_before_break: Set[Statement] = set()
+ self.uncheckpointed_statements: set[Statement] = set()
+ self.uncheckpointed_before_continue: set[Statement] = set()
+ self.uncheckpointed_before_break: set[Statement] = set()
self.default = self.get_state()
@@ -913,7 +901,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
self.set_state(outer)
# error if function exits or returns with uncheckpointed statements
- def check_function_exit(self, node: Union[ast.Return, ast.AsyncFunctionDef]):
+ def check_function_exit(self, node: ast.Return | ast.AsyncFunctionDef):
for statement in self.uncheckpointed_statements:
self.error(
"TRIO108" if self.has_yield else "TRIO107",
@@ -937,7 +925,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
self.set_state(outer)
# checkpoint functions
- def visit_Await(self, node: Union[ast.Await, ast.Raise]):
+ def visit_Await(self, node: ast.Await | ast.Raise):
# the expression being awaited is not checkpointed
# so only set checkpoint after the await node
self.generic_visit(node)
@@ -994,7 +982,7 @@ def visit_Try(self, node: ast.Try):
try_checkpoint = self.uncheckpointed_statements
# check that all except handlers checkpoint (await or most likely raise)
- except_uncheckpointed_statements: Set[Statement] = set()
+ except_uncheckpointed_statements: set[Statement] = set()
for handler in node.handlers:
# enter with worst case of try
@@ -1032,7 +1020,7 @@ def visit_Try(self, node: ast.Try):
self.uncheckpointed_statements.difference_update(added)
# valid checkpoint if both body and orelse checkpoint
- def visit_If(self, node: Union[ast.If, ast.IfExp]):
+ def visit_If(self, node: ast.If | ast.IfExp):
# visit condition
self.visit_nodes(node.test)
outer = self.uncheckpointed_statements.copy()
@@ -1055,7 +1043,7 @@ def visit_If(self, node: Union[ast.If, ast.IfExp]):
# after completing all of loop body, and after any continues.
# yield in else have same requirement
# state after the loop same as above, and in addition the state at any break
- def visit_loop(self, node: Union[ast.While, ast.For, ast.AsyncFor]):
+ def visit_loop(self, node: ast.While | ast.For | ast.AsyncFor):
# visit condition
infinite_loop = False
if isinstance(node, ast.While):
@@ -1307,8 +1295,8 @@ def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
- values: Optional[Sequence[Any]],
- option_string: Optional[str] = None,
+ values: Sequence[Any] | None,
+ option_string: str | None = None,
):
assert values is not None
assert option_string is not None
@@ -1327,7 +1315,7 @@ def __init__(self, tree: ast.AST):
self._tree = tree
@classmethod
- def from_filename(cls, filename: str) -> "Plugin":
+ def from_filename(cls, filename: str) -> Plugin:
with tokenize.open(filename) as f:
source = f.read()
return cls(ast.parse(source))
diff --git a/tests/conftest.py b/tests/conftest.py
index af3a52de..0bc2c0da 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,4 @@
-from typing import List
+from __future__ import annotations
import pytest
@@ -15,7 +15,7 @@ def pytest_configure(config: pytest.Config):
)
-def pytest_collection_modifyitems(config: pytest.Config, items: List[pytest.Item]):
+def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]):
if config.getoption("--runfuzz"):
# --runfuzz given in cli: do not skip fuzz tests
return
diff --git a/tests/test_changelog_and_version.py b/tests/test_changelog_and_version.py
index e925a287..2c8c941d 100644
--- a/tests/test_changelog_and_version.py
+++ b/tests/test_changelog_and_version.py
@@ -1,9 +1,11 @@
"""Tests for flake8-trio package metadata."""
+from __future__ import annotations
+
import os
import re
import unittest
from pathlib import Path
-from typing import Dict, Iterable, NamedTuple, Set
+from typing import Iterable, NamedTuple
from test_flake8_trio import trio_test_files_regex
@@ -51,7 +53,7 @@ def test_version_increments_are_correct():
class test_messages_documented(unittest.TestCase):
def runTest(self):
- documented_errors: Dict[str, Set[str]] = {}
+ documented_errors: dict[str, set[str]] = {}
for filename in (
"CHANGELOG.md",
"README.md",
@@ -71,8 +73,8 @@ def runTest(self):
if re.match(trio_test_files_regex, f)
}
- unique_errors: Dict[str, Set[str]] = {}
- missing_errors: Dict[str, Set[str]] = {}
+ unique_errors: dict[str, set[str]] = {}
+ missing_errors: dict[str, set[str]] = {}
for key, codes in documented_errors.items():
unique_errors[key] = codes.copy()
missing_errors[key] = set()
diff --git a/tests/test_decorator.py b/tests/test_decorator.py
index 90937b8d..9399989d 100644
--- a/tests/test_decorator.py
+++ b/tests/test_decorator.py
@@ -1,7 +1,8 @@
+from __future__ import annotations
+
import ast
import sys
from argparse import Namespace
-from typing import Tuple
import pytest
from flake8.main.application import Application
@@ -18,7 +19,7 @@ def dec_list(*decorators: str) -> ast.Module:
return tree
-def wrap(decorators: Tuple[str, ...], decs2: str) -> bool:
+def wrap(decorators: tuple[str, ...], decs2: str) -> bool:
tree = dec_list(*decorators)
assert isinstance(tree.body[0], ast.AsyncFunctionDef)
return fnmatch_qualified_name(tree.body[0].decorator_list, decs2)
diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py
index 29489272..f529381a 100644
--- a/tests/test_flake8_trio.py
+++ b/tests/test_flake8_trio.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import ast
import copy
import itertools
@@ -8,7 +10,7 @@
import tokenize
import unittest
from pathlib import Path
-from typing import DefaultDict, Iterable, List, Sequence, Tuple, Type
+from typing import DefaultDict, Iterable, Sequence
import pytest
from flake8 import __version_info__ as flake8_version_info
@@ -22,7 +24,7 @@
trio_test_files_regex = re.compile(r"trio\d\d\d(_py.*)?.py")
-test_files: List[Tuple[str, str]] = sorted(
+test_files: list[tuple[str, str]] = sorted(
(os.path.splitext(f)[0].upper(), f)
for f in os.listdir("tests")
if re.match(trio_test_files_regex, f)
@@ -62,7 +64,7 @@ def test_eval(test: str, path: str):
assert test in Error_codes.keys(), "error code not defined in flake8_trio.py"
include = [test]
- expected: List[Error] = []
+ expected: list[Error] = []
with open(os.path.join("tests", path), encoding="utf-8") as file:
lines = file.readlines()
@@ -134,7 +136,7 @@ def visit_Await(self, node: ast.Await):
newnode = self.generic_visit(node.value)
return newnode
- def replace_async(self, node: ast.AST, target: Type[ast.AST]) -> ast.AST:
+ def replace_async(self, node: ast.AST, target: type[ast.AST]) -> ast.AST:
node = self.generic_visit(node)
newnode = target()
newnode.__dict__ = node.__dict__
@@ -188,8 +190,8 @@ def assert_expected_errors(plugin: Plugin, include: Iterable[str], *expected: Er
def print_first_diff(errors: Sequence[Error], expected: Sequence[Error]):
- first_error_line: List[Error] = []
- first_expected_line: List[Error] = []
+ first_error_line: list[Error] = []
+ first_expected_line: list[Error] = []
for err, exp in zip(errors, expected):
if err == exp:
continue
diff --git a/typings/flake8/__init__.pyi b/typings/flake8/__init__.pyi
index bf8f2de6..2745afa7 100644
--- a/typings/flake8/__init__.pyi
+++ b/typings/flake8/__init__.pyi
@@ -5,8 +5,6 @@ This type stub file was generated by pyright.
from __future__ import annotations
import logging
-import sys
-from typing import Dict, Tuple
"""Top-level module for Flake8.
@@ -21,8 +19,8 @@ This module
"""
LOG: logging.Logger = ...
__version__: str = ...
-__version_info__: Tuple[int, ...] = ...
-_VERBOSITY_TO_LOG_LEVEL: Dict[int, int] = ...
+__version_info__: tuple[int, ...] = ...
+_VERBOSITY_TO_LOG_LEVEL: dict[int, int] = ...
LOG_FORMAT: str = ...
def configure_logging(
diff --git a/typings/flake8/options/manager.pyi b/typings/flake8/options/manager.pyi
index 89e2d324..540a60f3 100644
--- a/typings/flake8/options/manager.pyi
+++ b/typings/flake8/options/manager.pyi
@@ -3,18 +3,14 @@ This type stub file was generated by pyright.
Generated for flake8 5, so OptionManager signature is incorrect for flake8 6
"""
+from __future__ import annotations
+
import argparse
from typing import (
Any,
Callable,
- Dict,
- List,
Mapping,
- Optional,
Sequence,
- Tuple,
- Type,
- Union,
)
from flake8.plugins.finder import Plugins
@@ -22,7 +18,7 @@ from flake8.plugins.finder import Plugins
"""Option handling and Option management logic."""
LOG = ...
_ARG = ...
-_optparse_callable_map: Dict[str, Union[Type[Any], _ARG]] = ...
+_optparse_callable_map: dict[str, type[Any] | _ARG] = ...
class _CallbackAction(argparse.Action):
"""Shim for optparse-style callback actions."""
@@ -32,15 +28,15 @@ class _CallbackAction(argparse.Action):
*args: Any,
callback: Callable[..., Any],
callback_args: Sequence[Any] = ...,
- callback_kwargs: Optional[Dict[str, Any]] = ...,
+ callback_kwargs: dict[str, Any] | None = ...,
**kwargs: Any,
) -> None: ...
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
- values: Optional[Union[Sequence[str], str]],
- option_string: Optional[str] = ...,
+ values: Sequence[str] | str | None,
+ option_string: str | None = ...,
) -> None: ...
class Option:
@@ -48,21 +44,21 @@ class Option:
def __init__(
self,
- short_option_name: Union[str, _ARG] = ...,
- long_option_name: Union[str, _ARG] = ...,
- action: Union[str, Type[argparse.Action], _ARG] = ...,
- default: Union[Any, _ARG] = ...,
- type: Union[str, Callable[..., Any], _ARG] = ...,
- dest: Union[str, _ARG] = ...,
- nargs: Union[int, str, _ARG] = ...,
- const: Union[Any, _ARG] = ...,
- choices: Union[Sequence[Any], _ARG] = ...,
- help: Union[str, _ARG] = ...,
- metavar: Union[str, _ARG] = ...,
- callback: Union[Callable[..., Any], _ARG] = ...,
- callback_args: Union[Sequence[Any], _ARG] = ...,
- callback_kwargs: Union[Mapping[str, Any], _ARG] = ...,
- required: Union[bool, _ARG] = ...,
+ short_option_name: str | _ARG = ...,
+ long_option_name: str | _ARG = ...,
+ action: str | type[argparse.Action] | _ARG = ...,
+ default: Any | _ARG = ...,
+ type: str | Callable[..., Any] | _ARG = ...,
+ dest: str | _ARG = ...,
+ nargs: int | str | _ARG = ...,
+ const: Any | _ARG = ...,
+ choices: Sequence[Any] | _ARG = ...,
+ help: str | _ARG = ...,
+ metavar: str | _ARG = ...,
+ callback: Callable[..., Any] | _ARG = ...,
+ callback_args: Sequence[Any] | _ARG = ...,
+ callback_kwargs: Mapping[str, Any] | _ARG = ...,
+ required: bool | _ARG = ...,
parse_from_config: bool = ...,
comma_separated_list: bool = ...,
normalize_paths: bool = ...,
@@ -130,14 +126,14 @@ class Option:
"""
...
@property
- def filtered_option_kwargs(self) -> Dict[str, Any]:
+ def filtered_option_kwargs(self) -> dict[str, Any]:
"""Return any actually-specified arguments."""
...
def __repr__(self) -> str: ...
def normalize(self, value: Any, *normalize_args: str) -> Any:
"""Normalize the value based on the option configuration."""
...
- def to_argparse(self) -> Tuple[List[str], Dict[str, Any]]:
+ def to_argparse(self) -> tuple[list[str], dict[str, Any]]:
"""Convert a Flake8 Option to argparse ``add_argument`` arguments."""
...
@@ -149,7 +145,7 @@ class OptionManager:
*,
version: str,
plugin_versions: str,
- parents: List[argparse.ArgumentParser],
+ parents: list[argparse.ArgumentParser],
) -> None:
"""Initialize an instance of an OptionManager.
@@ -197,8 +193,8 @@ class OptionManager:
...
def parse_args(
self,
- args: Optional[Sequence[str]] = ...,
- values: Optional[argparse.Namespace] = ...,
+ args: Sequence[str] | None = ...,
+ values: argparse.Namespace | None = ...,
) -> argparse.Namespace:
"""Proxy to calling the OptionParser's parse_args method."""
...