diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 104bb21..d092473 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,13 +5,13 @@ ci: repos: - repo: https://github.com/crate-ci/typos - rev: v1 + rev: v1.31.1 hooks: - id: typos args: [] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.4 + rev: v0.11.7 hooks: - id: ruff args: ["--fix", "--unsafe-fixes"] diff --git a/src/app_model/expressions/_context_keys.py b/src/app_model/expressions/_context_keys.py index 6d323f2..d5f3cf0 100644 --- a/src/app_model/expressions/_context_keys.py +++ b/src/app_model/expressions/_context_keys.py @@ -97,7 +97,8 @@ def __init__( *, id: str = "", # optional because of __set_name__ ) -> None: - super().__init__(id or "") + bound = type(default_value) if default_value is not MISSING else None + super().__init__(id or "", bound=bound) self._default_value = default_value self._getter = getter self._description = description diff --git a/src/app_model/expressions/_expressions.py b/src/app_model/expressions/_expressions.py index cba12b0..7de3316 100644 --- a/src/app_model/expressions/_expressions.py +++ b/src/app_model/expressions/_expressions.py @@ -30,9 +30,21 @@ from pydantic.annotated_handlers import GetCoreSchemaHandler from pydantic_core import core_schema + from typing_extensions import TypedDict, Unpack from ._context_keys import ContextKey + # Used for node end positions in constructor keyword arguments + _EndPositionT = TypeVar("_EndPositionT", int, None) + + # Corresponds to the names in the `_attributes` + # class variable which is non-empty in certain AST nodes + class _Attributes(TypedDict, Generic[_EndPositionT], total=False): + lineno: int + col_offset: int + end_lineno: _EndPositionT + end_col_offset: _EndPositionT + def parse_expression(expr: Expr | str) -> Expr: """Parse string expression into an [`Expr`][app_model.expressions.Expr] instance. @@ -326,7 +338,7 @@ def __invert__(self) -> UnaryOp[T]: return UnaryOp(ast.Not(), self) def __reduce_ex__(self, protocol: SupportsIndex) -> tuple[Any, ...]: - rv = list(super().__reduce_ex__(protocol)) + rv: list[Any] = list(super().__reduce_ex__(protocol)) rv[1] = tuple(getattr(self, f) for f in self._fields) return tuple(rv) @@ -368,24 +380,45 @@ def _iter_names(self) -> Iterator[str]: class Name(Expr[T], ast.Name): """A variable name. - `id` holds the name as a string. + Parameters + ---------- + id : str + The name of the variable. + bound : Any | None + The type of the variable represented by this name (i.e. the type to which this + name evaluates to when used in an expression). This is used to provide type + hints when evaluating the expression. If `None`, the type is not known. """ - def __init__(self, id: str, ctx: ast.expr_context = LOAD, **kwargs: Any) -> None: - kwargs["ctx"] = LOAD - super().__init__(id, **kwargs) + def __init__( + self, + id: str, + ctx: ast.expr_context = LOAD, + *, + bound: type[T] | None = None, + **kwargs: Unpack[_Attributes], + ) -> None: + super().__init__(id, ctx=ctx, **kwargs) + self.bound = bound class Constant(Expr[V], ast.Constant): """A constant value. - The `value` attribute contains the Python object it represents. - types supported: NoneType, str, bytes, bool, int, float + Parameters + ---------- + value : V + the Python object this constant represents. + Types supported: NoneType, str, bytes, bool, int, float + kind : str | None + The kind of constant. This is used to provide type hints when """ value: V - def __init__(self, value: V, kind: str | None = None, **kwargs: Any) -> None: + def __init__( + self, value: V, kind: str | None = None, **kwargs: Unpack[_Attributes] + ) -> None: _valid_type = (type(None), str, bytes, bool, int, float) if not isinstance(value, _valid_type): raise TypeError(f"Constants must be type: {_valid_type!r}") @@ -405,13 +438,10 @@ def __init__( left: Expr, ops: Sequence[ast.cmpop], comparators: Sequence[Expr], - **kwargs: Any, + **kwargs: Unpack[_Attributes], ) -> None: super().__init__( - Expr._cast(left), - ops, - [Expr._cast(c) for c in comparators], - **kwargs, + Expr._cast(left), ops, [Expr._cast(c) for c in comparators], **kwargs ) @@ -426,9 +456,9 @@ def __init__( left: T | Expr[T], op: ast.operator, right: T | Expr[T], - **k: Any, + **kwargs: Unpack[_Attributes], ) -> None: - super().__init__(Expr._cast(left), op, Expr._cast(right), **k) + super().__init__(Expr._cast(left), op, Expr._cast(right), **kwargs) class BoolOp(Expr[T], ast.BoolOp): @@ -445,7 +475,7 @@ def __init__( self, op: ast.boolop, values: Sequence[ConstType | Expr], - **kwargs: Any, + **kwargs: Unpack[_Attributes], ): super().__init__(op, [Expr._cast(v) for v in values], **kwargs) @@ -456,7 +486,9 @@ class UnaryOp(Expr[T], ast.UnaryOp): `op` is the operator, and `operand` any expression node. """ - def __init__(self, op: ast.unaryop, operand: Expr, **kwargs: Any) -> None: + def __init__( + self, op: ast.unaryop, operand: Expr, **kwargs: Unpack[_Attributes] + ) -> None: super().__init__(op, Expr._cast(operand), **kwargs) @@ -466,7 +498,9 @@ class IfExp(Expr, ast.IfExp): `body` if `test` else `orelse` """ - def __init__(self, test: Expr, body: Expr, orelse: Expr, **kwargs: Any) -> None: + def __init__( + self, test: Expr, body: Expr, orelse: Expr, **kwargs: Unpack[_Attributes] + ) -> None: super().__init__( Expr._cast(test), Expr._cast(body), Expr._cast(orelse), **kwargs ) @@ -479,10 +513,12 @@ class Tuple(Expr, ast.Tuple): """ def __init__( - self, elts: Sequence[Expr], ctx: ast.expr_context = LOAD, **kwargs: Any + self, + elts: Sequence[Expr], + ctx: ast.expr_context = LOAD, + **kwargs: Unpack[_Attributes], ) -> None: - kwargs["ctx"] = ctx - super().__init__(elts=[Expr._cast(e) for e in elts], **kwargs) + super().__init__(elts=[Expr._cast(e) for e in elts], ctx=ctx, **kwargs) class List(Expr, ast.List): @@ -492,10 +528,12 @@ class List(Expr, ast.List): """ def __init__( - self, elts: Sequence[Expr], ctx: ast.expr_context = LOAD, **kwargs: Any + self, + elts: Sequence[Expr], + ctx: ast.expr_context = LOAD, + **kwargs: Unpack[_Attributes], ) -> None: - kwargs["ctx"] = ctx - super().__init__(elts=[Expr._cast(e) for e in elts], **kwargs) + super().__init__(elts=[Expr._cast(e) for e in elts], ctx=ctx, **kwargs) class Set(Expr, ast.Set): @@ -504,7 +542,7 @@ class Set(Expr, ast.Set): `elts` is a list of expressions. """ - def __init__(self, elts: Sequence[Expr], **kwargs: Any) -> None: + def __init__(self, elts: Sequence[Expr], **kwargs: Unpack[_Attributes]) -> None: super().__init__(elts=[Expr._cast(e) for e in elts], **kwargs) diff --git a/src/app_model/types/__init__.py b/src/app_model/types/__init__.py index 39a6530..72acf38 100644 --- a/src/app_model/types/__init__.py +++ b/src/app_model/types/__init__.py @@ -19,7 +19,9 @@ from ._menu_rule import MenuItem, MenuItemBase, MenuRule, SubmenuItem if TYPE_CHECKING: - from typing import Callable, TypeAlias + from typing import Callable + + from typing_extensions import TypeAlias from ._icon import IconOrDict as IconOrDict from ._keybinding_rule import KeyBindingRuleDict as KeyBindingRuleDict diff --git a/tests/test_context/test_expressions.py b/tests/test_context/test_expressions.py index 2f4a787..408eea1 100644 --- a/tests/test_context/test_expressions.py +++ b/tests/test_context/test_expressions.py @@ -8,7 +8,8 @@ def test_names(): - assert Name("n").eval({"n": 5}) == 5 + expr = Name("n", bound=int) + assert expr.eval({"n": 5}) == 5 # currently, evaludating with a missing name is an error. with pytest.raises(NameError):