From ec99558f326d0ae9cb5051e43908342181f62b54 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Thu, 6 Nov 2025 11:08:00 -0800 Subject: [PATCH] Start lifting over unions and add missing string functions --- tests/test_type_eval.py | 70 +++++++++++++++++++++ typemap/type_eval/_typing_inspect.py | 4 +- typemap/typing.py | 91 +++++++++++++++++++--------- 3 files changed, 136 insertions(+), 29 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 1cc5506..91be373 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -5,11 +5,15 @@ from typemap.typing import ( NewProtocol, Member, + GetAttr, GetName, GetType, Iter, Attrs, Is, + Uppercase, + StrConcat, + StrSlice, ) from typemap.type_eval import eval_typing @@ -89,3 +93,69 @@ def test_eval_types_3(): class F[bool]: fff: bool """) + + +class TA: + x: int + y: list[float] + z: TB + + +class TB: + x: str + y: list[object] + + +def test_type_getattr_union_1(): + d = eval_typing(GetAttr[TA | TB, Literal["x"]]) + assert d == int | str + + +def test_type_getattr_union_2(): + d = eval_typing(GetAttr[TA, Literal["x"] | Literal["y"]]) + assert d == int | list[float] + + +def test_type_getattr_union_3(): + d = eval_typing(GetAttr[TA | TB, Literal["x"] | Literal["y"]]) + assert d == int | list[float] | str | list[object] + + +def test_type_getattr_union_4(): + d = eval_typing(GetAttr[TA, Literal["x", "y"]]) + assert d == int | list[float] + + +def test_type_getattr_union_5(): + d = eval_typing(GetAttr[TA, Literal["x", "y"] | Literal["z"]]) + assert d == int | list[float] | TB + + +def test_type_strings_1(): + d = eval_typing(Uppercase[Literal["foo"]]) + assert d == Literal["FOO"] + + +def test_type_strings_2(): + d = eval_typing(Uppercase[Literal["foo", "bar"]]) + assert d == Literal["FOO"] | Literal["BAR"] + + +def test_type_strings_3(): + d = eval_typing(StrConcat[Literal["foo"], Literal["bar"]]) + assert d == Literal["foobar"] + + +def test_type_strings_4(): + d = eval_typing(StrConcat[Literal["a", "b"], Literal["c", "d"]]) + assert d == Literal["ac"] | Literal["ad"] | Literal["bc"] | Literal["bd"] + + +def test_type_strings_5(): + d = eval_typing(StrSlice[Literal["abcd"], Literal[0], Literal[1]]) + assert d == Literal["a"] + + +def test_type_strings_6(): + d = eval_typing(StrSlice[Literal["abcd"], Literal[1], Literal[None]]) + assert d == Literal["bcd"] diff --git a/typemap/type_eval/_typing_inspect.py b/typemap/type_eval/_typing_inspect.py index 17d5d79..008e37e 100644 --- a/typemap/type_eval/_typing_inspect.py +++ b/typemap/type_eval/_typing_inspect.py @@ -127,7 +127,9 @@ def is_literal(t: Any) -> bool: def get_head(t: Any) -> type | None: if is_generic_alias(t): - return get_origin(t) + return get_head(get_origin(t)) + elif is_eval_proxy(t): + return get_head(t.__origin__) elif isinstance(t, type): return t else: diff --git a/typemap/typing.py b/typemap/typing.py index 3cac5fc..d6ee566 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -1,6 +1,8 @@ from dataclasses import dataclass +import functools import inspect +import itertools import types import typing @@ -109,9 +111,42 @@ def get_annotated_type_hints(cls, **kwargs): return hints +def _split_args(func): + @functools.wraps(func) + def wrapper(self, arg): + if isinstance(arg, tuple): + return func(self, *arg) + else: + return func(self, arg) + + return wrapper + + +def _union_elems(tp): + tp = type_eval.eval_typing(tp) + if isinstance(tp, types.UnionType): + return tuple(y for x in tp.__args__ for y in _union_elems(x)) + elif _typing_inspect.is_literal(tp) and len(tp.__args__) > 1: + return tuple(typing.Literal[x] for x in tp.__args__) + else: + return (tp,) + + +def _lift_over_unions(func): + @functools.wraps(func) + @_split_args + def wrapper(self, *args): + args2 = [_union_elems(x) for x in args] + # XXX: Never + parts = [func(self, *x) for x in itertools.product(*args2)] + return typing.Union[*parts] + + return wrapper + + @_SpecialForm +@_lift_over_unions def Attrs(self, tp): - # TODO: Support unions o = type_eval.eval_typing(tp) hints = get_annotated_type_hints(o, include_extras=True) @@ -141,6 +176,8 @@ def _ann(x): for _i, p in enumerate(sig.parameters.values()): # XXX: what should we do about self? # should we track classmethod/staticmethod somehow? + # mypy stores all this stuff in the SymbolNodes (FuncDef, etc), + # even though it kind of really is a type/descriptor thing # if i == 0 and is_method: # continue has_name = p.kind in ( @@ -166,8 +203,8 @@ def _ann(x): @_SpecialForm +@_lift_over_unions def Members(self, tp): - # TODO: Support unions o = type_eval.eval_typing(tp) hints = get_annotated_type_hints(o, include_extras=True) @@ -214,20 +251,16 @@ def Iter(self, tp): @_SpecialForm def FromUnion(self, tp): - tp = type_eval.eval_typing(tp) - if isinstance(tp, types.UnionType): - return tuple[*tp.__args__] - else: - return tuple[tp] + return tuple[*_union_elems(tp)] ################################################################## @_SpecialForm -def GetAttr(self, arg): - # TODO: Unions, the prop missing, etc! - lhs, prop = arg +@_lift_over_unions +def GetAttr(self, lhs, prop): + # TODO: the prop missing, etc! # XXX: extras? name = _from_literal(type_eval.eval_typing(prop)) return typing.get_type_hints(type_eval.eval_typing(lhs))[name] @@ -260,9 +293,8 @@ def _get_args(tp, base) -> typing.Any: @_SpecialForm -def GetArg(self, arg) -> typing.Any: - # XXX: Unions - tp, base, idx = arg +@_lift_over_unions +def GetArg(self, tp, base, idx) -> typing.Any: args = _get_args(tp, base) if args is None: return typing.Never @@ -275,10 +307,12 @@ def GetArg(self, arg) -> typing.Any: ################################################################## +# N.B: These handle unions on their own + @_SpecialForm -def IsSubtype(self, arg): - lhs, rhs = arg +@_split_args +def IsSubtype(self, lhs, rhs): return type_eval.issubtype( type_eval.eval_typing(lhs), type_eval.eval_typing(rhs), @@ -286,8 +320,8 @@ def IsSubtype(self, arg): @_SpecialForm -def IsSubSimilar(self, arg): - lhs, rhs = arg +@_split_args +def IsSubSimilar(self, lhs, rhs): return type_eval.issubsimilar( type_eval.eval_typing(lhs), type_eval.eval_typing(rhs), @@ -299,21 +333,22 @@ def IsSubSimilar(self, arg): ################################################################## -# TODO: unions! Slice, Concat - -class _StringLiteralOp: - def __init__(self, op: typing.Callable[[str], str]): - self.op = op +def _string_literal_op(op): + @_SpecialForm + @_lift_over_unions + def func(self, *args): + return typing.Literal[op(*[_from_literal(x) for x in args])] - def __getitem__(self, arg): - return typing.Literal[self.op(_from_literal(arg))] + return func -Uppercase = _StringLiteralOp(op=str.upper) -Lowercase = _StringLiteralOp(op=str.lower) -Capitalize = _StringLiteralOp(op=str.capitalize) -Uncapitalize = _StringLiteralOp(op=lambda s: s[0:1].lower() + s[1:]) +Uppercase = _string_literal_op(op=str.upper) +Lowercase = _string_literal_op(op=str.lower) +Capitalize = _string_literal_op(op=str.capitalize) +Uncapitalize = _string_literal_op(op=lambda s: s[0:1].lower() + s[1:]) +StrConcat = _string_literal_op(op=lambda s, t: s + t) +StrSlice = _string_literal_op(op=lambda s, start, end: s[start:end]) ##################################################################