Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 42 additions & 54 deletions flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,14 @@
Pairs well with flake8-async and flake8-bugbear.
"""

from __future__ import annotations

import argparse
import ast
import tokenize
from argparse import Namespace
from fnmatch import fnmatch
from typing import (
Any,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from typing import Any, Iterable, NamedTuple, Sequence, Union, cast

from flake8.options.manager import OptionManager

Expand Down Expand Up @@ -95,7 +85,7 @@ def __eq__(self, other: Any) -> bool:

def get_matching_call(
node: ast.AST, *names: str, base: str = "trio"
) -> Optional[Tuple[ast.Call, str]]:
) -> tuple[ast.Call, str] | None:
if (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
Expand Down Expand Up @@ -150,7 +140,7 @@ def __repr__(self) -> str:
class Flake8TrioVisitor(ast.NodeVisitor):
def __init__(self, options: Namespace):
super().__init__()
self._problems: List[Error] = []
self._problems: list[Error] = []
self.suppress_errors = False
self.options = options

Expand All @@ -160,9 +150,7 @@ def run(cls, tree: ast.AST, options: Namespace) -> Iterable[Error]:
visitor.visit(tree)
yield from visitor._problems

def visit_nodes(
self, *nodes: Union[ast.AST, Iterable[ast.AST]], generic: bool = False
):
def visit_nodes(self, *nodes: ast.AST | Iterable[ast.AST], generic: bool = False):
if generic:
visit = self.generic_visit
else:
Expand All @@ -178,10 +166,10 @@ def error(self, error: str, node: HasLineCol, *args: object):
if not self.suppress_errors:
self._problems.append(Error(error, node.lineno, node.col_offset, *args))

def get_state(self, *attrs: str, copy: bool = False) -> Dict[str, Any]:
def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]:
if not attrs:
attrs = tuple(self.__dict__.keys())
res: Dict[str, Any] = {}
res: dict[str, Any] = {}
for attr in attrs:
if attr == "_problems":
continue
Expand All @@ -191,7 +179,7 @@ def get_state(self, *attrs: str, copy: bool = False) -> Dict[str, Any]:
res[attr] = value
return res

def set_state(self, attrs: Dict[str, Any], copy: bool = False):
def set_state(self, attrs: dict[str, Any], copy: bool = False):
for attr, value in attrs.items():
if copy and hasattr(value, "copy"):
value = value.copy()
Expand All @@ -203,7 +191,7 @@ def walk(self, *body: ast.AST) -> Iterable[ast.AST]:


# ignores module and only checks the unqualified name of the decorator
def has_decorator(decorator_list: List[ast.expr], *names: str):
def has_decorator(decorator_list: list[ast.expr], *names: str):
for dec in decorator_list:
if (isinstance(dec, ast.Name) and dec.id in names) or (
isinstance(dec, ast.Attribute) and dec.attr in names
Expand All @@ -214,8 +202,8 @@ def has_decorator(decorator_list: List[ast.expr], *names: str):

# matches the fully qualified name against fnmatch pattern
# used to match decorators and methods to user-supplied patterns
def fnmatch_qualified_name(name_list: List[ast.expr], *patterns: str):
def construct_name(expr: ast.expr) -> Optional[str]:
def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str):
def construct_name(expr: ast.expr) -> str | None:
if isinstance(expr, ast.Call):
expr = expr.func
if isinstance(expr, ast.Name):
Expand Down Expand Up @@ -257,13 +245,13 @@ def __init__(self, *args: Any, **kwargs: Any):
self._safe_decorator = False

# 111
self._context_managers: List[VisitorMiscChecks.TrioContextManager] = []
self._nursery_call: Optional[VisitorMiscChecks.NurseryCall] = None
self._context_managers: list[VisitorMiscChecks.TrioContextManager] = []
self._nursery_call: VisitorMiscChecks.NurseryCall | None = None

self.defaults = self.get_state(copy=True)

# ---- 100, 101, 111, 112 ----
def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
def visit_With(self, node: ast.With | ast.AsyncWith):
self.check_for_trio100(node)
self.check_for_trio112(node)

Expand Down Expand Up @@ -305,7 +293,7 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
visit_AsyncWith = visit_With

# ---- 100 ----
def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
def check_for_trio100(self, node: ast.With | ast.AsyncWith):
# Context manager with no `await trio.X` call within
for item in (i.context_expr for i in node.items):
call = get_matching_call(item, *cancel_scope_names)
Expand All @@ -316,7 +304,7 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
self.error("TRIO100", item, f"trio.{call[1]}")

# ---- 101 ----
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
outer = self.get_state()
self.set_state(self.defaults, copy=True)

Expand Down Expand Up @@ -433,7 +421,7 @@ def visit_Name(self, node: ast.Name):
# if with has a withitem `trio.open_nursery() as <X>`,
# and the body is only a single expression <X>.start[_soon](),
# and does not pass <X> as a parameter to the expression
def check_for_trio112(self, node: Union[ast.With, ast.AsyncWith]):
def check_for_trio112(self, node: ast.With | ast.AsyncWith):
# body is single expression
if len(node.body) != 1 or not isinstance(node.body[0], ast.Expr):
return
Expand Down Expand Up @@ -462,8 +450,8 @@ def check_for_trio112(self, node: Union[ast.With, ast.AsyncWith]):


# used in 102, 103 and 104
def critical_except(node: ast.ExceptHandler) -> Optional[Statement]:
def has_exception(node: Optional[ast.expr]) -> str:
def critical_except(node: ast.ExceptHandler) -> Statement | None:
def has_exception(node: ast.expr | None) -> str:
if isinstance(node, ast.Name) and node.id == "BaseException":
return "BaseException"
if (
Expand Down Expand Up @@ -496,7 +484,7 @@ class TrioScope:
def __init__(self, node: ast.Call, funcname: str):
self.node = node
self.funcname = funcname
self.variable_name: Optional[str] = None
self.variable_name: str | None = None
self.shielded: bool = False
self.has_timeout: bool = True

Expand All @@ -514,14 +502,14 @@ def __init__(self, node: ast.Call, funcname: str):

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._critical_scope: Optional[Statement] = None
self._trio_context_managers: List[Visitor102.TrioScope] = []
self._critical_scope: Statement | None = None
self._trio_context_managers: list[Visitor102.TrioScope] = []

# if we're inside a finally, and we're not inside a scope that doesn't have
# both a timeout and shield
def visit_Await(
self,
node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith],
node: ast.Await | ast.AsyncFor | ast.AsyncWith,
visit_children: bool = True,
):
if self._critical_scope is not None and not any(
Expand All @@ -533,7 +521,7 @@ def visit_Await(

visit_AsyncFor = visit_Await

def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
def visit_With(self, node: ast.With | ast.AsyncWith):
has_context_manager = False

# Check for a `with trio.<scope_creater>`
Expand Down Expand Up @@ -564,7 +552,7 @@ def visit_AsyncWith(self, node: ast.AsyncWith):

def critical_visit(
self,
node: Union[ast.ExceptHandler, Iterable[ast.AST]],
node: ast.ExceptHandler | Iterable[ast.AST],
block: Statement,
generic: bool = False,
):
Expand Down Expand Up @@ -612,7 +600,7 @@ def visit_Assign(self, node: ast.Assign):
class Visitor103_104(Flake8TrioVisitor):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.except_name: Optional[str] = ""
self.except_name: str | None = ""
self.unraised: bool = False
self.unraised_break: bool = False
self.unraised_continue: bool = False
Expand Down Expand Up @@ -659,7 +647,7 @@ def visit_Raise(self, node: ast.Raise):

self.generic_visit(node)

def visit_Return(self, node: Union[ast.Return, ast.Yield]):
def visit_Return(self, node: ast.Return | ast.Yield):
if self.unraised:
# Error: must re-raise
self.error("TRIO104", node)
Expand Down Expand Up @@ -709,7 +697,7 @@ def visit_If(self, node: ast.If):
# else always raises, and
# always raise before break
# or body always raises (before break) and is guaranteed to run at least once
def visit_For(self, node: Union[ast.For, ast.While]):
def visit_For(self, node: ast.For | ast.While):
if not self.unraised:
self.generic_visit(node)
return
Expand Down Expand Up @@ -829,7 +817,7 @@ class Visitor105(Flake8TrioVisitor):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
# keep a node stack so we can check whether calls are awaited
self.node_stack: List[ast.AST] = []
self.node_stack: list[ast.AST] = []

def visit(self, node: ast.AST):
self.node_stack.append(node)
Expand All @@ -854,7 +842,7 @@ def visit_Call(self, node: ast.Call):
self.generic_visit(node)


def empty_body(body: List[ast.stmt]) -> bool:
def empty_body(body: list[ast.stmt]) -> bool:
# Does the function body consist solely of `pass`, `...`, and (doc)string literals?
return all(
isinstance(stmt, ast.Pass)
Expand All @@ -874,9 +862,9 @@ def __init__(self, *args: Any, **kwargs: Any):
self.safe_decorator = False
self.async_function = False

self.uncheckpointed_statements: Set[Statement] = set()
self.uncheckpointed_before_continue: Set[Statement] = set()
self.uncheckpointed_before_break: Set[Statement] = set()
self.uncheckpointed_statements: set[Statement] = set()
self.uncheckpointed_before_continue: set[Statement] = set()
self.uncheckpointed_before_break: set[Statement] = set()

self.default = self.get_state()

Expand Down Expand Up @@ -913,7 +901,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
self.set_state(outer)

# error if function exits or returns with uncheckpointed statements
def check_function_exit(self, node: Union[ast.Return, ast.AsyncFunctionDef]):
def check_function_exit(self, node: ast.Return | ast.AsyncFunctionDef):
for statement in self.uncheckpointed_statements:
self.error(
"TRIO108" if self.has_yield else "TRIO107",
Expand All @@ -937,7 +925,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
self.set_state(outer)

# checkpoint functions
def visit_Await(self, node: Union[ast.Await, ast.Raise]):
def visit_Await(self, node: ast.Await | ast.Raise):
# the expression being awaited is not checkpointed
# so only set checkpoint after the await node
self.generic_visit(node)
Expand Down Expand Up @@ -994,7 +982,7 @@ def visit_Try(self, node: ast.Try):
try_checkpoint = self.uncheckpointed_statements

# check that all except handlers checkpoint (await or most likely raise)
except_uncheckpointed_statements: Set[Statement] = set()
except_uncheckpointed_statements: set[Statement] = set()

for handler in node.handlers:
# enter with worst case of try
Expand Down Expand Up @@ -1032,7 +1020,7 @@ def visit_Try(self, node: ast.Try):
self.uncheckpointed_statements.difference_update(added)

# valid checkpoint if both body and orelse checkpoint
def visit_If(self, node: Union[ast.If, ast.IfExp]):
def visit_If(self, node: ast.If | ast.IfExp):
# visit condition
self.visit_nodes(node.test)
outer = self.uncheckpointed_statements.copy()
Expand All @@ -1055,7 +1043,7 @@ def visit_If(self, node: Union[ast.If, ast.IfExp]):
# after completing all of loop body, and after any continues.
# yield in else have same requirement
# state after the loop same as above, and in addition the state at any break
def visit_loop(self, node: Union[ast.While, ast.For, ast.AsyncFor]):
def visit_loop(self, node: ast.While | ast.For | ast.AsyncFor):
# visit condition
infinite_loop = False
if isinstance(node, ast.While):
Expand Down Expand Up @@ -1307,8 +1295,8 @@ def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Sequence[Any]],
option_string: Optional[str] = None,
values: Sequence[Any] | None,
option_string: str | None = None,
):
assert values is not None
assert option_string is not None
Expand All @@ -1327,7 +1315,7 @@ def __init__(self, tree: ast.AST):
self._tree = tree

@classmethod
def from_filename(cls, filename: str) -> "Plugin":
def from_filename(cls, filename: str) -> Plugin:
with tokenize.open(filename) as f:
source = f.read()
return cls(ast.parse(source))
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from __future__ import annotations

import pytest

Expand All @@ -15,7 +15,7 @@ def pytest_configure(config: pytest.Config):
)


def pytest_collection_modifyitems(config: pytest.Config, items: List[pytest.Item]):
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]):
if config.getoption("--runfuzz"):
# --runfuzz given in cli: do not skip fuzz tests
return
Expand Down
10 changes: 6 additions & 4 deletions tests/test_changelog_and_version.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Tests for flake8-trio package metadata."""
from __future__ import annotations

import os
import re
import unittest
from pathlib import Path
from typing import Dict, Iterable, NamedTuple, Set
from typing import Iterable, NamedTuple

from test_flake8_trio import trio_test_files_regex

Expand Down Expand Up @@ -51,7 +53,7 @@ def test_version_increments_are_correct():

class test_messages_documented(unittest.TestCase):
def runTest(self):
documented_errors: Dict[str, Set[str]] = {}
documented_errors: dict[str, set[str]] = {}
for filename in (
"CHANGELOG.md",
"README.md",
Expand All @@ -71,8 +73,8 @@ def runTest(self):
if re.match(trio_test_files_regex, f)
}

unique_errors: Dict[str, Set[str]] = {}
missing_errors: Dict[str, Set[str]] = {}
unique_errors: dict[str, set[str]] = {}
missing_errors: dict[str, set[str]] = {}
for key, codes in documented_errors.items():
unique_errors[key] = codes.copy()
missing_errors[key] = set()
Expand Down
5 changes: 3 additions & 2 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import ast
import sys
from argparse import Namespace
from typing import Tuple

import pytest
from flake8.main.application import Application
Expand All @@ -18,7 +19,7 @@ def dec_list(*decorators: str) -> ast.Module:
return tree


def wrap(decorators: Tuple[str, ...], decs2: str) -> bool:
def wrap(decorators: tuple[str, ...], decs2: str) -> bool:
tree = dec_list(*decorators)
assert isinstance(tree.body[0], ast.AsyncFunctionDef)
return fnmatch_qualified_name(tree.body[0].decorator_list, decs2)
Expand Down
Loading