4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,12 +1,12 @@
exclude: ^python/tests/__snapshots__/
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.14.3
+rev: v0.15.0
hooks:
- id: ruff-check
args: [--fix]
- id: ruff-format
- repo: https://github.com/astral-sh/uv-pre-commit
-rev: 0.9.7
+rev: 0.10.0
hooks:
- id: uv-lock
2 changes: 1 addition & 1 deletion docs/explanation/2023_11_17_pytensor.ipynb
@@ -167,7 +167,7 @@
")\n",
"converter(int, IntTuple, lambda i: IntTuple(Int(i64(i))))\n",
"converter(i64, IntTuple, lambda i: IntTuple(Int(i)))\n",
"converter(Int, IntTuple, lambda i: IntTuple(i))\n",
"converter(Int, IntTuple, IntTuple)\n",
"\n",
"\n",
"@egraph.register\n",
2 changes: 1 addition & 1 deletion python/egglog/__init__.py
@@ -4,7 +4,7 @@

from . import config, ipython_magic # noqa: F401
from .bindings import EggSmolError, StageInfo, TimeOnly, WithPlan # noqa: F401
-from .builtins import * # noqa: UP029
+from .builtins import *
from .conversion import *
from .deconstruct import *
from .egraph import *
2 changes: 1 addition & 1 deletion python/egglog/builtins.py
@@ -801,7 +801,7 @@ def bool_le(self, other: BigIntLike) -> Bool: ...
def bool_ge(self, other: BigIntLike) -> Bool: ...


-converter(i64, BigInt, lambda i: BigInt(i))
+converter(i64, BigInt, BigInt)

BigIntLike: TypeAlias = BigInt | i64Like

48 changes: 24 additions & 24 deletions python/egglog/exp/array_api.py
@@ -297,7 +297,7 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))


-converter(i64, Int, lambda x: Int(x))
+converter(i64, Int, Int)

IntLike: TypeAlias = Int | i64Like

@@ -377,8 +377,8 @@ def __gt__(self, other: FloatLike) -> Boolean: ...
def __ge__(self, other: FloatLike) -> Boolean: ...


-converter(float, Float, lambda x: Float(x))
-converter(Int, Float, lambda x: Float.from_int(x))
+converter(float, Float, Float)
+converter(Int, Float, Float.from_int)


FloatLike: TypeAlias = Float | float | IntLike
@@ -521,7 +521,7 @@ def deselect(self, indices: TupleIntLike) -> TupleInt:
return TupleInt.range(self.length()).filter(lambda i: ~indices.contains(i)).map(lambda i: self[i])


-converter(Vec[Int], TupleInt, lambda x: TupleInt.from_vec(x))
+converter(Vec[Int], TupleInt, TupleInt.from_vec)

TupleIntLike: TypeAlias = TupleInt | VecLike[Int, IntLike]

@@ -649,7 +649,7 @@ def product(self) -> TupleTupleInt:
)


-converter(Vec[TupleInt], TupleTupleInt, lambda x: TupleTupleInt.from_vec(x))
+converter(Vec[TupleInt], TupleTupleInt, TupleTupleInt.from_vec)

TupleTupleIntLike: TypeAlias = TupleTupleInt | VecLike[TupleInt, TupleIntLike]

@@ -755,8 +755,8 @@ def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...


-converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
-converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
+converter(DType, IsDtypeKind, IsDtypeKind.dtype)
+converter(str, IsDtypeKind, IsDtypeKind.string)
converter(
tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
)
@@ -922,8 +922,8 @@ def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue:
return TupleValue(ti.length(), lambda i: Value.int(ti[i]))


-converter(Vec[Value], TupleValue, lambda x: TupleValue.from_vec(x))
-converter(TupleInt, TupleValue, lambda x: TupleValue.from_tuple_int(x))
+converter(Vec[Value], TupleValue, TupleValue.from_vec)
+converter(TupleInt, TupleValue, TupleValue.from_tuple_int)

TupleValueLike: TypeAlias = TupleValue | VecLike[Value, ValueLike] | TupleIntLike

@@ -1073,9 +1073,9 @@ def ndarray(cls, key: NDArray) -> IndexKey:


converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS)
-converter(Int, IndexKey, lambda i: IndexKey.int(i))
-converter(Slice, IndexKey, lambda s: IndexKey.slice(s))
-converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m))
+converter(Int, IndexKey, IndexKey.int)
+converter(Slice, IndexKey, IndexKey.slice)
+converter(MultiAxisIndexKey, IndexKey, IndexKey.multi_axis)


class Device(Expr, ruleset=array_api_ruleset): ...
@@ -1232,13 +1232,13 @@ def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ...

NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike

-converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v))
-converter(Value, NDArray, lambda v: NDArray.scalar(v))
+converter(NDArray, IndexKey, IndexKey.ndarray)
+converter(Value, NDArray, NDArray.scalar)
# Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
# to prefer upcasting in the other direction when we can, which is safer at runtime
converter(NDArray, Value, lambda n: n.to_value(), 100)
-converter(TupleValue, NDArray, lambda v: NDArray.vector(v))
-converter(TupleInt, TupleValue, lambda v: TupleValue.from_tuple_int(v))
+converter(TupleValue, NDArray, NDArray.vector)
+converter(TupleInt, TupleValue, TupleValue.from_tuple_int)


@array_api_ruleset.register
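
The comment above about making the NDArray-to-Value conversion "more expensive" refers to the trailing 100 passed to converter; as that comment implies, conversions carry a relative cost so the cheap Value-to-NDArray upcast wins whenever both directions are possible. The following is only a toy model of that idea, not egglog's actual conversion machinery, and the type names are just labels:

# Toy cost-based conversion registry (illustrative only): each registered
# conversion is an edge with a cost, and the cheapest chain is preferred.
import heapq

EDGES = {
    ("Value", "NDArray"): 1,    # cheap default: wrap the scalar as an array
    ("NDArray", "Value"): 100,  # deliberately expensive fallback direction
}

def cheapest_cost(src: str, dst: str) -> float:
    # Dijkstra over the conversion graph: lowest total cost between two types.
    best = {src: 0.0}
    heap = [(0.0, src)]
    while heap:
        cost, node = heapq.heappop(heap)
        if node == dst:
            return cost
        for (a, b), weight in EDGES.items():
            if a == node and cost + weight < best.get(b, float("inf")):
                best[b] = cost + weight
                heapq.heappush(heap, (cost + weight, b))
    return float("inf")

assert cheapest_cost("Value", "NDArray") < cheapest_cost("NDArray", "Value")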
@@ -1322,7 +1322,7 @@ def eval(self) -> tuple[NDArray, ...]:
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)


-converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))
+converter(Vec[NDArray], TupleNDArray, TupleNDArray.from_vec)

TupleNDArrayLike: TypeAlias = TupleNDArray | VecLike[NDArray, NDArrayLike]

@@ -1371,7 +1371,7 @@ def some(cls, value: Boolean) -> OptionalBool: ...


converter(type(None), OptionalBool, lambda _: OptionalBool.none)
-converter(Boolean, OptionalBool, lambda x: OptionalBool.some(x))
+converter(Boolean, OptionalBool, OptionalBool.some)


class OptionalDType(Expr, ruleset=array_api_ruleset):
@@ -1382,7 +1382,7 @@ def some(cls, value: DType) -> OptionalDType: ...


converter(type(None), OptionalDType, lambda _: OptionalDType.none)
-converter(DType, OptionalDType, lambda x: OptionalDType.some(x))
+converter(DType, OptionalDType, OptionalDType.some)


class OptionalDevice(Expr, ruleset=array_api_ruleset):
@@ -1393,7 +1393,7 @@ def some(cls, value: Device) -> OptionalDevice: ...


converter(type(None), OptionalDevice, lambda _: OptionalDevice.none)
-converter(Device, OptionalDevice, lambda x: OptionalDevice.some(x))
+converter(Device, OptionalDevice, OptionalDevice.some)


class OptionalTupleInt(Expr, ruleset=array_api_ruleset):
@@ -1404,7 +1404,7 @@ def some(cls, value: TupleIntLike) -> OptionalTupleInt: ...


converter(type(None), OptionalTupleInt, lambda _: OptionalTupleInt.none)
-converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x))
+converter(TupleInt, OptionalTupleInt, OptionalTupleInt.some)


class IntOrTuple(Expr, ruleset=array_api_ruleset):
@@ -1417,8 +1417,8 @@ def int(cls, value: Int) -> IntOrTuple: ...
def tuple(cls, value: TupleIntLike) -> IntOrTuple: ...


-converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v))
-converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v))
+converter(Int, IntOrTuple, IntOrTuple.int)
+converter(TupleInt, IntOrTuple, IntOrTuple.tuple)


class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
@@ -1429,7 +1429,7 @@ def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ...


converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none)
-converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v))
+converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)


@function
2 changes: 1 addition & 1 deletion python/egglog/exp/array_api_loopnest.py
@@ -31,7 +31,7 @@ def shape_api_ruleset(dims: TupleInt, axis: TupleInt):
ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i]))
)
yield rewrite(s.select(axis), subsume=True).to(
-ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: axis.contains(i)).map(lambda i: dims[i]))
+ShapeAPI(TupleInt.range(dims.length()).filter(axis.contains).map(lambda i: dims[i]))
)
yield rewrite(s.to_tuple(), subsume=True).to(dims)
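
The same simplification applied to a bound method: axis.contains is already a one-argument predicate, so it can be handed to filter directly. A plain-Python illustration of the general principle (egglog wraps such callables in its own function values, which this sketch does not model):

# Sketch: a bound method is a ready-made one-argument callable, so no
# wrapping lambda is needed to use it as a predicate.
axis = {0, 2}
kept = [i for i in range(4) if axis.__contains__(i)]
also_kept = list(filter(axis.__contains__, range(4)))
assert kept == also_kept == [0, 2]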

21 changes: 13 additions & 8 deletions python/tests/test_array_api.py
@@ -210,11 +210,14 @@ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
return NDArray(
outshape,
X.dtype,
-lambda k: LoopNestAPI.from_tuple(reduce_axis)
-.unwrap()
-.indices()
-.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
-.sqrt(),
+lambda k: (
+LoopNestAPI
+.from_tuple(reduce_axis)
+.unwrap()
+.indices()
+.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
+.sqrt()
+),
)


@@ -224,9 +227,11 @@ def linalg_norm_v2(X: NDArrayLike, axis: TupleIntLike) -> NDArray:
return NDArray(
X.shape.deselect(axis),
X.dtype,
-lambda k: ndindex(X.shape.select(axis))
-.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
-.sqrt(),
+lambda k: (
+ndindex(X.shape.select(axis))
+.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
+.sqrt()
+),
)
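
The two test changes above reformat the same expression rather than alter it: wrapping the fluent chain in parentheses lets each call start on its own line without backslash continuations. A small standalone example of the same layout:

# Parentheses let a method chain span several lines, one call per line.
result = (
    " Hello, World "
    .strip()
    .lower()
    .replace("world", "arrays")
)
assert result == "hello, arrays"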


4 changes: 2 additions & 2 deletions python/tests/test_convert.py
@@ -95,7 +95,7 @@ def test_convert_to_generic():
class G(BuiltinExpr, Generic[T]):
def __init__(self, x: T) -> None: ...

-converter(i64, G[i64], lambda x: G(x))
+converter(i64, G[i64], G)
assert expr_parts(convert(10, G[i64])) == expr_parts(G(i64(10)))

with pytest.raises(ConvertError):
@@ -114,7 +114,7 @@ def test_convert_to_unbound_generic():
class G(BuiltinExpr, Generic[T]):
def __init__(self, x: i64) -> None: ...

-converter(i64, G, lambda x: G[get_type_args()[0]](x)) # type: ignore[misc, operator]
+converter(i64, G, G[get_type_args()[0]]) # type: ignore[misc, operator]
assert expr_parts(convert(10, G[String])) == expr_parts(G[String](i64(10)))


8 changes: 4 additions & 4 deletions python/tests/test_high_level.py
@@ -535,14 +535,14 @@ def _global_make_tuple(x):


def test_eval_fn_globals():
-assert EGraph().extract(PyObject(lambda x: _global_make_tuple(x))(PyObject.from_int(1))).value == (1,)
+assert EGraph().extract(PyObject(_global_make_tuple)(PyObject.from_int(1))).value == (1,)


def test_eval_fn_locals():
def _locals_make_tuple(x):
return (x,)

-assert EGraph().extract(PyObject(lambda x: _locals_make_tuple(x))(PyObject.from_int(1))).value == (1,)
+assert EGraph().extract(PyObject(_locals_make_tuple)(PyObject.from_int(1))).value == (1,)


def test_lazy_types():
@@ -1459,9 +1459,9 @@ def __contains__(self, item: int) -> bool:
pytest.param(lambda: int(m), 1000, id="int"),
pytest.param(lambda: float(m), 100.0, id="float"),
pytest.param(lambda: complex(m), 1 + 0j, id="complex"),
-pytest.param(lambda: m.__index__(), 20, id="index"),
+pytest.param(m.__index__, 20, id="index"),
pytest.param(lambda: len(m), 10, id="len"),
-pytest.param(lambda: m.__length_hint__(), 5, id="length_hint"),
+pytest.param(m.__length_hint__, 5, id="length_hint"),
pytest.param(lambda: list(m), [1], id="iter"),
pytest.param(lambda: list(reversed(m)), [10], id="reversed"),
pytest.param(lambda: 1 in m, True, id="contains"),
8 changes: 4 additions & 4 deletions python/tests/test_unstable_fn.py
@@ -253,7 +253,7 @@ def apply_f(f: Callable[[A], A], x: A) -> A:

@r.register
def _rewrite(a: A):
-yield rewrite(transform_a(a)).to(apply_f(lambda x: my_transform_a(x), a))
+yield rewrite(transform_a(a)).to(apply_f(my_transform_a, a))

assert check_eq(transform_a(A()), my_transform_a(A()), r * 10)

@@ -276,7 +276,7 @@ def apply_f(f: Callable[[A], A], x: A) -> A:

@ruleset
def my_ruleset(a: A):
-yield rewrite(transform_a(a)).to(apply_f(lambda x: my_transform_a(x), a))
+yield rewrite(transform_a(a)).to(apply_f(my_transform_a, a))

assert check_eq(transform_a(A()), my_transform_a(A()), (my_ruleset | apply_ruleset) * 10)

@@ -296,7 +296,7 @@ def apply_f(f: Callable[[A], A], x: A) -> A:

@function(ruleset=r)
def transform_a(a: A) -> A:
-return apply_f(lambda x: my_transform_a(x), a)
+return apply_f(my_transform_a, a)

assert check_eq(transform_a(A()), my_transform_a(A()), r * 10)

@@ -325,7 +325,7 @@ def higher_order(f: Callable[[A], A]) -> A: ...
@function
def transform_a(a: A) -> A: ...

-v = higher_order(lambda a: transform_a(a))
+v = higher_order(transform_a)
assert str(v) == "higher_order(lambda a: transform_a(a))"

def test_multiple_same(self):