diff --git a/python/tvm/expr.py b/python/tvm/expr.py index b265103360c6..7575037ebc11 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -16,7 +16,7 @@ """ # pylint: disable=missing-docstring from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, register_node +from ._ffi.node import NodeBase, NodeGeneric, register_node from . import make as _make from . import _api_internal @@ -89,10 +89,10 @@ def __le__(self, other): return _make.LE(self, other) def __eq__(self, other): - return self.equal(other) + return EqualOp(self, other) def __ne__(self, other): - return _make.NE(self, other) + return NotEqualOp(self, other) def __gt__(self, other): return _make.GT(self, other) @@ -138,12 +138,71 @@ def astype(self, dtype): return _make.static_cast(dtype, self) +class EqualOp(NodeGeneric, ExprOp): + """Deferred equal operator. + + This is used to support sugar that a == b can either + mean NodeBase.same_as or NodeBase.equal. + + Parameters + ---------- + a : Expr + Left operand. + + b : Expr + Right operand. + """ + def __init__(self, a, b): + self.a = a + self.b = b + + def __nonzero__(self): + return self.a.same_as(self.b) + + def __bool__(self): + return self.__nonzero__() + + def asnode(self): + """Convert node.""" + return _make.EQ(self.a, self.b) + + +class NotEqualOp(NodeGeneric, ExprOp): + """Deferred NE operator. + + This is used to support sugar that a != b can either + mean not NodeBase.same_as or make.NE. + + Parameters + ---------- + a : Expr + Left operand. + + b : Expr + Right operand. + """ + def __init__(self, a, b): + self.a = a + self.b = b + + def __nonzero__(self): + return not self.a.same_as(self.b) + + def __bool__(self): + return self.__nonzero__() + + def asnode(self): + """Convert node.""" + return _make.NE(self.a, self.b) + + class Expr(ExprOp, NodeBase): """Base class of all tvm Expressions""" # In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ __hash__ = NodeBase.__hash__ + class ConstExpr(Expr): pass @@ -215,19 +274,11 @@ class Max(BinaryOpExpr): @register_node class EQ(CmpExpr): - def __nonzero__(self): - return self.a.same_as(self.b) - - def __bool__(self): - return self.__nonzero__() + pass @register_node class NE(CmpExpr): - def __nonzero__(self): - return not self.a.same_as(self.b) - - def __bool__(self): - return self.__nonzero__() + pass @register_node class LT(CmpExpr): diff --git a/tests/python/unittest/test_ir_builder.py b/tests/python/unittest/test_ir_builder.py index c5cc192a3f33..86b43c3f7980 100644 --- a/tests/python/unittest/test_ir_builder.py +++ b/tests/python/unittest/test_ir_builder.py @@ -31,6 +31,7 @@ def test_if(): A[0] = A[i] + 2 body = ib.get() + assert A == A assert isinstance(body, tvm.stmt.For) body = body.body assert isinstance(body, tvm.stmt.IfThenElse) @@ -42,6 +43,7 @@ def test_prefetch(): A = tvm.placeholder((10, 20), name="A") ib = tvm.ir_builder.create() n = tvm.var("n") + with ib.for_range(0, n, name="i") as i: ib.emit( tvm.make.Prefetch(