From c4bbdbc70fcfd7a5c16d4af62fe985b5a2c4822e Mon Sep 17 00:00:00 2001 From: AishwaryaElango Date: Tue, 1 Apr 2025 17:18:07 +0530 Subject: [PATCH 01/10] Add support for logaddexp core operator --- include/tvm/tir/expr.h | 19 ++++++++ include/tvm/tir/expr_functor.h | 4 ++ include/tvm/tir/op.h | 11 +++++ include/tvm/topi/broadcast.h | 16 +++++++ .../torch/exported_program_translator.py | 1 + python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/binary.py | 19 ++++++++ .../relax/transform/legalize_ops/binary.py | 1 + python/tvm/te/__init__.py | 2 +- python/tvm/tir/__init__.py | 4 +- python/tvm/tir/expr.py | 19 ++++++++ python/tvm/tir/op.py | 44 +++++++++++++++++++ python/tvm/topi/broadcast.py | 19 ++++++++ src/relax/op/tensor/binary.cc | 1 + src/relax/op/tensor/binary.h | 3 ++ src/tir/ir/expr.cc | 9 ++++ src/tir/ir/expr_functor.cc | 2 + src/tir/ir/tir_visitor_with_path.cc | 1 + src/tir/ir/tir_visitor_with_path.h | 1 + src/tir/op/op.cc | 8 ++++ src/tir/transforms/lower_intrin.cc | 8 ++++ src/topi/broadcast.cc | 1 + 22 files changed, 191 insertions(+), 3 deletions(-) diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 06ee75070ce7..6491e5c427ce 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -261,6 +261,25 @@ class FloorDiv : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode); }; +/*! + * \brief LogAddExp operation, computes log(exp(a) + exp(b)). + */ +class LogAddExpNode : public BinaryOpNode { + public: + static constexpr const char* _type_key = "tir.LogAddExp"; + }; + + /*! + * \brief Managed reference to LogAddExpNode. + * \sa LogAddExpNode + */ + class LogAddExp : public PrimExpr { + public: + TVM_DLL LogAddExp(PrimExpr a, PrimExpr b, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(LogAddExp, PrimExpr, LogAddExpNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(LogAddExpNode); + }; + /*! \brief The remainder of the floordiv */ class FloorModNode : public BinaryOpNode { public: diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index dfa9d7e1e346..0804f5a76847 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -128,6 +128,7 @@ class ExprFunctor { virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LogAddExpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -170,6 +171,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(DivNode); IR_EXPR_FUNCTOR_DISPATCH(ModNode); IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode); + IR_EXPR_FUNCTOR_DISPATCH(LogAddExpNode); IR_EXPR_FUNCTOR_DISPATCH(FloorModNode); IR_EXPR_FUNCTOR_DISPATCH(MinNode); IR_EXPR_FUNCTOR_DISPATCH(MaxNode); @@ -221,6 +223,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor { void VisitExpr_(const DivNode* op) override; void VisitExpr_(const ModNode* op) override; void VisitExpr_(const FloorDivNode* op) override; + void VisitExpr_(const LogAddExpNode* op) override; void VisitExpr_(const FloorModNode* op) override; void VisitExpr_(const MinNode* op) override; void VisitExpr_(const MaxNode* op) override; @@ -266,6 +269,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor { PrimExpr VisitExpr_(const DivNode* op) override; PrimExpr VisitExpr_(const ModNode* op) override; PrimExpr VisitExpr_(const FloorDivNode* op) override; + PrimExpr VisitExpr_(const LogAddExpNode* op) override; PrimExpr VisitExpr_(const FloorModNode* op) override; PrimExpr VisitExpr_(const MinNode* op) override; PrimExpr VisitExpr_(const MaxNode* op) override; diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index cfbd44529515..ce7a425c94f9 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -394,6 +394,15 @@ TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span()); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span()); +/*! + * \brief Compute log(exp(a) + exp(b)). + * + * \param a Left operand. + * \param b Right operand. + * \param span The location of this operation in the source. + * \return The result expression. + */ +TVM_DLL PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span = Span()); /*! * \brief compute ceil(a / b) * @@ -404,6 +413,7 @@ TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span()); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ + TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span()); /*! * \brief compute the remainder of floordiv @@ -1071,6 +1081,7 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(indexmod); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncdiv); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncmod); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floordiv); +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(logaddexp); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floormod); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(right_shift); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(left_shift); // NOLINT(*) diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index d27b6f1a3cfe..4a727514ee2f 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -257,6 +257,22 @@ TOPI_DEFINE_BCAST_OP(floor_divide, { } }); +/*! + * \fn log_add_exp + * \brief Compute log(exp(A) + exp(B)) with auto-broadcasting. + * + * This operation is useful for numerically stable log-sum-exp computations, + * which frequently appear in probabilistic and statistical models. + * + * \param A The first input tensor, or Expr. + * \param B The second input tensor, or Expr. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return The computed log-sum-exp result. + */ +TOPI_DEFINE_BCAST_OP(log_add_exp, { return tvm::logaddexp(a, b); }); + /*! * \fn trunc divide * \brief Compute trunc(A / B) with auto-broadcasting. diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 97ccc6393cbb..7bb49d0e375a 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -295,6 +295,7 @@ def create_convert_map( "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), "eq.Tensor": self._binary_op(relax.op.equal, operator.eq), "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv), + "logaddexp.default": self._binary_op(relax.op.log_add_exp,torch.logaddexp), "ge.Scalar": self._binary_op(relax.op.greater_equal, operator.ge), "ge.Tensor": self._binary_op(relax.op.greater_equal, operator.ge), "gt.Scalar": self._binary_op(relax.op.greater, operator.gt), diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 97f18a239640..ddfdfc2b05d8 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -50,6 +50,7 @@ divide, equal, floor_divide, + log_add_exp, floor_mod, greater, greater_equal, diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index 7a41c8b0953c..776e1c99dae1 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -85,6 +85,25 @@ def floor_divide(x1: Expr, x2: Expr) -> Expr: return _ffi_api.floor_divide(x1, x2) # type: ignore +def log_add_exp(x1: Expr, x2: Expr) -> Expr: + """ + Compute the log of the sum of exponentials of the inputs, element-wise. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + + Returns + ------- + Expr + The element-wise log-sum-exp of `x1` and `x2`. + """ + return _ffi_api.log_add_exp(x1, x2) + + def multiply(x1: Expr, x2: Expr) -> Expr: """Multiplication with numpy-style broadcasting. diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py index 41e317f1e0ef..6883ffbda39e 100644 --- a/python/tvm/relax/transform/legalize_ops/binary.py +++ b/python/tvm/relax/transform/legalize_ops/binary.py @@ -44,6 +44,7 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.add", _binary(topi.add)) register_legalize("relax.divide", _binary(topi.divide)) register_legalize("relax.floor_divide", _binary(topi.floor_divide)) +register_legalize("relax.log_add_exp",_binary(topi.log_add_exp)) register_legalize("relax.multiply", _binary(topi.multiply)) register_legalize("relax.power", _binary(topi.power)) register_legalize("relax.subtract", _binary(topi.subtract)) diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index b31853bea666..362419bebf58 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -24,7 +24,7 @@ from tvm.tir import asin, asinh, acos, acosh, atan, atanh from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else from tvm.tir import isnan, isfinite, isinf -from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod +from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, logaddexp from tvm.tir import comm_reducer, min, max, sum from tvm.tir import add, subtract, multiply diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 4f56ec3c15bc..535dd5532770 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -23,7 +23,7 @@ from .data_layout import Layout, BijectiveLayout, bijective_layout, layout from .expr import convert from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast -from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod +from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod, LogAddExp from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle from .expr import Call, CallEffectKind, Let, IterVar, CommReducer @@ -90,7 +90,7 @@ from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else from .op import likely, isnan, isnullptr, isfinite, isinf, copysign -from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv +from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv, logaddexp from .op import comm_reducer, min, max, sum from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 6cd4302133c5..72feb63b42eb 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -801,6 +801,25 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore +@tvm._ffi.register_object("tir.LogAddExp") +class LogAddExp(BinaryOpExpr): + """LogAddExp node. + + Parameters + ---------- + a : PrimExpr + The left hand operand. + + b : PrimExpr + The right hand operand. + + span : Optional[Span] + The location of this expression in the source code. + """ + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.LogAddExp, a, b, span) # type: ignore + + @tvm._ffi.register_object("tir.FloorMod") class FloorMod(BinaryOpExpr): """FloorMod node. diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 53c92fff86dc..b71bf6b88c08 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3221,6 +3221,50 @@ def floordiv(a, b, span=None): return _ffi_api._OpFloorDiv(a, b, span) # type: ignore +def floordiv(a, b, span=None): + """Compute the floordiv of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + """ + return _ffi_api._OpFloorDiv(a, b, span) # type: ignore + + +def logaddexp(a, b, span=None): + """Compute the logaddexp of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + """ + return _ffi_api._OpLogAddExp(a, b, span) # type: ignore + + def floormod(a, b, span=None): """Compute the floormod of two expressions. diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py index 2b350ff817d9..c748d54b19b7 100644 --- a/python/tvm/topi/broadcast.py +++ b/python/tvm/topi/broadcast.py @@ -135,6 +135,25 @@ def floor_divide(lhs, rhs): return _cpp.floor_divide(lhs, rhs) +def log_add_exp(lhs,rhs): + """Log-sum-exp operation with auto-broadcasting. + + Parameters + ---------- + x1 : tvm.te.Tensor or Expr + The first input tensor or expression. + x2 : tvm.te.Tensor or Expr + The second input tensor or expression. + + Returns + ------- + ret : tvm.te.Tensor or Expr + Returns an Expr if both operands are Expr. + Otherwise, returns a Tensor. + """ + return _cpp.log_add_exp(lhs,rhs) + + def mod(lhs, rhs): """Modulus with auto-broadcasting diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 4a63993d507c..e7fab8f166e1 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -193,6 +193,7 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call, RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(divide); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(log_add_exp); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract); diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index b66eb96f8452..6b106f760d5f 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -70,6 +70,9 @@ Expr divide(Expr x1, Expr x2); /*! \brief Floor division with numpy-style broadcasting. */ Expr floor_divide(Expr x1, Expr x2); +/*! \brief Log Add Exponent with numpy-style broadcasting. */ +Expr log_add_exp(Expr x1, Expr x2); + /*! \brief Multiplication with numpy-style broadcasting. */ Expr multiply(Expr x1, Expr x2); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index b52c85df3575..c91a7a670d5c 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -277,6 +277,15 @@ TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Sp TVM_REGISTER_NODE_TYPE(FloorDivNode); +// LogAddExp +TVM_DEFINE_BINOP_CONSTRUCTOR(LogAddExp); + +TVM_REGISTER_GLOBAL("tir.LogAddExp").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return LogAddExp(a, b, span); +}); + +TVM_REGISTER_NODE_TYPE(LogAddExpNode); + // FloorMod TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 05e333b78ac6..786febab4c60 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -61,6 +61,7 @@ DEFINE_BINOP_VISIT_(MulNode); DEFINE_BINOP_VISIT_(DivNode); DEFINE_BINOP_VISIT_(ModNode); DEFINE_BINOP_VISIT_(FloorDivNode); +DEFINE_BINOP_VISIT_(LogAddExpNode); DEFINE_BINOP_VISIT_(FloorModNode); DEFINE_BINOP_VISIT_(MinNode); DEFINE_BINOP_VISIT_(MaxNode); @@ -182,6 +183,7 @@ DEFINE_BIOP_EXPR_MUTATE_(Mul); DEFINE_BIOP_EXPR_MUTATE_(Div); DEFINE_BIOP_EXPR_MUTATE_(Mod); DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); +DEFINE_BIOP_EXPR_MUTATE_(LogAddExp); DEFINE_BIOP_EXPR_MUTATE_(FloorMod); DEFINE_BIOP_EXPR_MUTATE_(Min); DEFINE_BIOP_EXPR_MUTATE_(Max); diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 4f5007aedb3f..ecfa00f07c0a 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -377,6 +377,7 @@ DEFINE_BINOP_VISIT_(MulNode); DEFINE_BINOP_VISIT_(DivNode); DEFINE_BINOP_VISIT_(ModNode); DEFINE_BINOP_VISIT_(FloorDivNode); +DEFINE_BINOP_VISIT_(LogAddExpNode); DEFINE_BINOP_VISIT_(FloorModNode); DEFINE_BINOP_VISIT_(MinNode); DEFINE_BINOP_VISIT_(MaxNode); diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 61441541da32..1f3a4b08450c 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -131,6 +131,7 @@ class TIRVisitorWithPath : protected ExprFunctora).as()); + PrimExpr exp_b = VisitExpr_(tvm::exp(op->b).as()); + PrimExpr sum = Add(exp_a, exp_b); + PrimExpr log_sum = VisitExpr_(tvm::log(sum).as()); + return log_sum; + } + PrimExpr VisitExpr_(const FloorModNode* op) final { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index f6a28c7722af..2105172aed40 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -52,6 +52,7 @@ TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract); TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply); TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide); TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide); +TOPI_REGISTER_BCAST_OP("topi.log_add_exp", topi::log_add_exp); TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod); TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod); TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum); From 908dce8a40454b6827a4dcd5c6682a5e64b63a24 Mon Sep 17 00:00:00 2001 From: AishwaryaElango Date: Thu, 3 Apr 2025 10:01:53 +0530 Subject: [PATCH 02/10] Add test script for logaddexp --- include/tvm/tir/expr.h | 2 +- python/tvm/relax/op/binary.py | 2 +- tests/python/relax/test_frontend_from_exported_program.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 6491e5c427ce..1ef8f1a606e2 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -268,7 +268,7 @@ class LogAddExpNode : public BinaryOpNode { public: static constexpr const char* _type_key = "tir.LogAddExp"; }; - + /*! * \brief Managed reference to LogAddExpNode. * \sa LogAddExpNode diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index 776e1c99dae1..d18aac863535 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -88,7 +88,7 @@ def floor_divide(x1: Expr, x2: Expr) -> Expr: def log_add_exp(x1: Expr, x2: Expr) -> Expr: """ Compute the log of the sum of exponentials of the inputs, element-wise. - + Parameters ---------- x1 : Expr diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 98f0f1d9cac6..396bff0c3bed 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -840,6 +840,7 @@ def main( (operator.mul, R.multiply), (operator.truediv, R.divide), (operator.floordiv, R.floor_divide), + (torch.logaddexp, R.log_add_exp), (operator.pow, R.power), (operator.mod, R.mod), (operator.and_, R.bitwise_and), From f46adde48264da454bd16bca9782ef81aa07c9fc Mon Sep 17 00:00:00 2001 From: AishwaryaElango Date: Thu, 3 Apr 2025 11:03:31 +0530 Subject: [PATCH 03/10] Add fix for lint issues --- .../torch/exported_program_translator.py | 2 +- .../relax/transform/legalize_ops/binary.py | 2 +- python/tvm/tir/expr.py | 1 + python/tvm/tir/op.py | 22 ------------------- python/tvm/topi/broadcast.py | 4 ++-- 5 files changed, 5 insertions(+), 26 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7bb49d0e375a..733649264ae1 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -295,7 +295,7 @@ def create_convert_map( "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), "eq.Tensor": self._binary_op(relax.op.equal, operator.eq), "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv), - "logaddexp.default": self._binary_op(relax.op.log_add_exp,torch.logaddexp), + "logaddexp.default": self._binary_op(relax.op.log_add_exp, torch.logaddexp), "ge.Scalar": self._binary_op(relax.op.greater_equal, operator.ge), "ge.Tensor": self._binary_op(relax.op.greater_equal, operator.ge), "gt.Scalar": self._binary_op(relax.op.greater, operator.gt), diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py index 6883ffbda39e..1acbddb2190b 100644 --- a/python/tvm/relax/transform/legalize_ops/binary.py +++ b/python/tvm/relax/transform/legalize_ops/binary.py @@ -44,7 +44,7 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.add", _binary(topi.add)) register_legalize("relax.divide", _binary(topi.divide)) register_legalize("relax.floor_divide", _binary(topi.floor_divide)) -register_legalize("relax.log_add_exp",_binary(topi.log_add_exp)) +register_legalize("relax.log_add_exp", _binary(topi.log_add_exp)) register_legalize("relax.multiply", _binary(topi.multiply)) register_legalize("relax.power", _binary(topi.power)) register_legalize("relax.subtract", _binary(topi.subtract)) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 72feb63b42eb..e84d4122d137 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -816,6 +816,7 @@ class LogAddExp(BinaryOpExpr): span : Optional[Span] The location of this expression in the source code. """ + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.LogAddExp, a, b, span) # type: ignore diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index b71bf6b88c08..3770a8be5fd2 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3221,28 +3221,6 @@ def floordiv(a, b, span=None): return _ffi_api._OpFloorDiv(a, b, span) # type: ignore -def floordiv(a, b, span=None): - """Compute the floordiv of two expressions. - - Parameters - ---------- - a : PrimExpr - The left hand operand - - b : PrimExpr - The right hand operand - - span : Optional[Span] - The location of this operator in the source. - - Returns - ------- - res : PrimExpr - The result expression. - """ - return _ffi_api._OpFloorDiv(a, b, span) # type: ignore - - def logaddexp(a, b, span=None): """Compute the logaddexp of two expressions. diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py index c748d54b19b7..e2982ecfc21b 100644 --- a/python/tvm/topi/broadcast.py +++ b/python/tvm/topi/broadcast.py @@ -135,7 +135,7 @@ def floor_divide(lhs, rhs): return _cpp.floor_divide(lhs, rhs) -def log_add_exp(lhs,rhs): +def log_add_exp(lhs, rhs): """Log-sum-exp operation with auto-broadcasting. Parameters @@ -151,7 +151,7 @@ def log_add_exp(lhs,rhs): Returns an Expr if both operands are Expr. Otherwise, returns a Tensor. """ - return _cpp.log_add_exp(lhs,rhs) + return _cpp.log_add_exp(lhs, rhs) def mod(lhs, rhs): From 2c8e1cc7b2e768be550079e995ba2a11c1f3d326 Mon Sep 17 00:00:00 2001 From: AishwaryaElango Date: Thu, 3 Apr 2025 11:08:42 +0530 Subject: [PATCH 04/10] Adjust trailing spaces --- python/tvm/tir/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index e84d4122d137..3ee4cd0c7f35 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -816,7 +816,7 @@ class LogAddExp(BinaryOpExpr): span : Optional[Span] The location of this expression in the source code. """ - + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.LogAddExp, a, b, span) # type: ignore From 26b3ffdeddbf08ac2e8bf765b86c3a216b10fe77 Mon Sep 17 00:00:00 2001 From: AishwaryaElango Date: Thu, 3 Apr 2025 12:02:48 +0530 Subject: [PATCH 05/10] Adjust leading whitespace --- include/tvm/tir/expr.h | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 1ef8f1a606e2..8ff224bc6660 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -265,20 +265,20 @@ class FloorDiv : public PrimExpr { * \brief LogAddExp operation, computes log(exp(a) + exp(b)). */ class LogAddExpNode : public BinaryOpNode { - public: - static constexpr const char* _type_key = "tir.LogAddExp"; - }; + public: + static constexpr const char* _type_key = "tir.LogAddExp"; +}; - /*! +/*! * \brief Managed reference to LogAddExpNode. * \sa LogAddExpNode - */ - class LogAddExp : public PrimExpr { - public: - TVM_DLL LogAddExp(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(LogAddExp, PrimExpr, LogAddExpNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(LogAddExpNode); - }; + */ +class LogAddExp : public PrimExpr { + public: + TVM_DLL LogAddExp(PrimExpr a, PrimExpr b, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(LogAddExp, PrimExpr, LogAddExpNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(LogAddExpNode); +}; /*! \brief The remainder of the floordiv */ class FloorModNode : public BinaryOpNode { From 0847e4c218bec9fc85e8de51c0af59681dcdc1e9 Mon Sep 17 00:00:00 2001 From: AishwaryaElango Date: Thu, 3 Apr 2025 13:16:34 +0530 Subject: [PATCH 06/10] Add fix for lint inssues --- include/tvm/tir/expr.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 8ff224bc6660..a7cd090e5ab9 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -270,8 +270,8 @@ class LogAddExpNode : public BinaryOpNode { }; /*! - * \brief Managed reference to LogAddExpNode. - * \sa LogAddExpNode + * \brief Managed reference to LogAddExpNode. + * \sa LogAddExpNode */ class LogAddExp : public PrimExpr { public: From 7e6f07e88bab771b30910bb0b7f08636baa1bca8 Mon Sep 17 00:00:00 2001 From: AishwaryaElango Date: Thu, 3 Apr 2025 15:02:39 +0530 Subject: [PATCH 07/10] Add fix for logaddexp test script --- python/tvm/script/ir_builder/relax/ir.py | 2 ++ .../test_frontend_from_exported_program.py | 27 ++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index ddc534cf6086..6fa3cc61cbbc 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -112,6 +112,7 @@ less_equal, linear, log, + log_add_exp, logical_and, logical_not, logical_or, @@ -794,6 +795,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "less_equal", "linear", "log", + "log_add_exp", "logical_and", "logical_not", "logical_or", diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 396bff0c3bed..314e0ea8fb8f 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -682,6 +682,32 @@ def main( verify_model(LeakyReLU1(), example_args, {}, expected) +def test_logaddexp(): + class LogAddExp(Module): + def forward(self, input1, input2): + return torch.logaddexp(input1, input2) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_2: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log_add_exp(input_1, input_2) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.randn(1, 3, 10, 10, dtype=torch.float32), + torch.randn(1, 3, 10, 10, dtype=torch.float32) + ) + verify_model(LogAddExp(), example_args, {}, expected) + + def test_logsoftmax(): class LogSoftmax(Module): def __init__(self): @@ -840,7 +866,6 @@ def main( (operator.mul, R.multiply), (operator.truediv, R.divide), (operator.floordiv, R.floor_divide), - (torch.logaddexp, R.log_add_exp), (operator.pow, R.power), (operator.mod, R.mod), (operator.and_, R.bitwise_and), From 8c83f06775d990cc6ef9d893886c2118e908f8f0 Mon Sep 17 00:00:00 2001 From: AishwaryaElango Date: Thu, 3 Apr 2025 15:25:14 +0530 Subject: [PATCH 08/10] Fix lint issues --- tests/python/relax/test_frontend_from_exported_program.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 314e0ea8fb8f..46029c856e5c 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -692,7 +692,7 @@ class expected: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), - input_2: R.Tensor((1, 3, 10, 10), dtype="float32") + input_2: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): @@ -703,7 +703,7 @@ def main( example_args = ( torch.randn(1, 3, 10, 10, dtype=torch.float32), - torch.randn(1, 3, 10, 10, dtype=torch.float32) + torch.randn(1, 3, 10, 10, dtype=torch.float32), ) verify_model(LogAddExp(), example_args, {}, expected) From a84e5f15e98ba06d809423a90afcd4c2ac28bf9e Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Tue, 15 Apr 2025 16:38:29 +0000 Subject: [PATCH 09/10] decomposition at op level --- include/tvm/tir/expr.h | 19 ------------------- include/tvm/tir/expr_functor.h | 4 ---- include/tvm/topi/broadcast.h | 2 +- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/expr.py | 20 -------------------- src/tir/ir/expr.cc | 11 ----------- src/tir/ir/expr_functor.cc | 2 -- src/tir/ir/tir_visitor_with_path.cc | 1 - src/tir/ir/tir_visitor_with_path.h | 1 - src/tir/op/op.cc | 4 +++- src/tir/transforms/lower_intrin.cc | 8 -------- 11 files changed, 5 insertions(+), 69 deletions(-) diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index a7cd090e5ab9..06ee75070ce7 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -261,25 +261,6 @@ class FloorDiv : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode); }; -/*! - * \brief LogAddExp operation, computes log(exp(a) + exp(b)). - */ -class LogAddExpNode : public BinaryOpNode { - public: - static constexpr const char* _type_key = "tir.LogAddExp"; -}; - -/*! - * \brief Managed reference to LogAddExpNode. - * \sa LogAddExpNode - */ -class LogAddExp : public PrimExpr { - public: - TVM_DLL LogAddExp(PrimExpr a, PrimExpr b, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(LogAddExp, PrimExpr, LogAddExpNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(LogAddExpNode); -}; - /*! \brief The remainder of the floordiv */ class FloorModNode : public BinaryOpNode { public: diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index 0804f5a76847..dfa9d7e1e346 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -128,7 +128,6 @@ class ExprFunctor { virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const LogAddExpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -171,7 +170,6 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(DivNode); IR_EXPR_FUNCTOR_DISPATCH(ModNode); IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode); - IR_EXPR_FUNCTOR_DISPATCH(LogAddExpNode); IR_EXPR_FUNCTOR_DISPATCH(FloorModNode); IR_EXPR_FUNCTOR_DISPATCH(MinNode); IR_EXPR_FUNCTOR_DISPATCH(MaxNode); @@ -223,7 +221,6 @@ class TVM_DLL ExprVisitor : public ExprFunctor { void VisitExpr_(const DivNode* op) override; void VisitExpr_(const ModNode* op) override; void VisitExpr_(const FloorDivNode* op) override; - void VisitExpr_(const LogAddExpNode* op) override; void VisitExpr_(const FloorModNode* op) override; void VisitExpr_(const MinNode* op) override; void VisitExpr_(const MaxNode* op) override; @@ -269,7 +266,6 @@ class TVM_DLL ExprMutator : protected ExprFunctor { PrimExpr VisitExpr_(const DivNode* op) override; PrimExpr VisitExpr_(const ModNode* op) override; PrimExpr VisitExpr_(const FloorDivNode* op) override; - PrimExpr VisitExpr_(const LogAddExpNode* op) override; PrimExpr VisitExpr_(const FloorModNode* op) override; PrimExpr VisitExpr_(const MinNode* op) override; PrimExpr VisitExpr_(const MaxNode* op) override; diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index 4a727514ee2f..9be7256b446e 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -271,7 +271,7 @@ TOPI_DEFINE_BCAST_OP(floor_divide, { * * \return The computed log-sum-exp result. */ -TOPI_DEFINE_BCAST_OP(log_add_exp, { return tvm::logaddexp(a, b); }); +TOPI_DEFINE_BCAST_OP(log_add_exp, { return logaddexp(a, b); }); /*! * \fn trunc divide diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 535dd5532770..5ceb48127038 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -23,7 +23,7 @@ from .data_layout import Layout, BijectiveLayout, bijective_layout, layout from .expr import convert from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast -from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod, LogAddExp +from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle from .expr import Call, CallEffectKind, Let, IterVar, CommReducer diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 3ee4cd0c7f35..6cd4302133c5 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -801,26 +801,6 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.LogAddExp") -class LogAddExp(BinaryOpExpr): - """LogAddExp node. - - Parameters - ---------- - a : PrimExpr - The left hand operand. - - b : PrimExpr - The right hand operand. - - span : Optional[Span] - The location of this expression in the source code. - """ - - def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.LogAddExp, a, b, span) # type: ignore - - @tvm._ffi.register_object("tir.FloorMod") class FloorMod(BinaryOpExpr): """FloorMod node. diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index c91a7a670d5c..defd74284f02 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -275,17 +275,6 @@ TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Sp return FloorDiv(a, b, span); }); -TVM_REGISTER_NODE_TYPE(FloorDivNode); - -// LogAddExp -TVM_DEFINE_BINOP_CONSTRUCTOR(LogAddExp); - -TVM_REGISTER_GLOBAL("tir.LogAddExp").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { - return LogAddExp(a, b, span); -}); - -TVM_REGISTER_NODE_TYPE(LogAddExpNode); - // FloorMod TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 786febab4c60..05e333b78ac6 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -61,7 +61,6 @@ DEFINE_BINOP_VISIT_(MulNode); DEFINE_BINOP_VISIT_(DivNode); DEFINE_BINOP_VISIT_(ModNode); DEFINE_BINOP_VISIT_(FloorDivNode); -DEFINE_BINOP_VISIT_(LogAddExpNode); DEFINE_BINOP_VISIT_(FloorModNode); DEFINE_BINOP_VISIT_(MinNode); DEFINE_BINOP_VISIT_(MaxNode); @@ -183,7 +182,6 @@ DEFINE_BIOP_EXPR_MUTATE_(Mul); DEFINE_BIOP_EXPR_MUTATE_(Div); DEFINE_BIOP_EXPR_MUTATE_(Mod); DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); -DEFINE_BIOP_EXPR_MUTATE_(LogAddExp); DEFINE_BIOP_EXPR_MUTATE_(FloorMod); DEFINE_BIOP_EXPR_MUTATE_(Min); DEFINE_BIOP_EXPR_MUTATE_(Max); diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index ecfa00f07c0a..4f5007aedb3f 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -377,7 +377,6 @@ DEFINE_BINOP_VISIT_(MulNode); DEFINE_BINOP_VISIT_(DivNode); DEFINE_BINOP_VISIT_(ModNode); DEFINE_BINOP_VISIT_(FloorDivNode); -DEFINE_BINOP_VISIT_(LogAddExpNode); DEFINE_BINOP_VISIT_(FloorModNode); DEFINE_BINOP_VISIT_(MinNode); DEFINE_BINOP_VISIT_(MaxNode); diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 1f3a4b08450c..61441541da32 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -131,7 +131,6 @@ class TIRVisitorWithPath : protected ExprFunctora).as()); - PrimExpr exp_b = VisitExpr_(tvm::exp(op->b).as()); - PrimExpr sum = Add(exp_a, exp_b); - PrimExpr log_sum = VisitExpr_(tvm::log(sum).as()); - return log_sum; - } - PrimExpr VisitExpr_(const FloorModNode* op) final { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); From f54d33daf3d2defc2d7c04cd92eb9f77cb521309 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Wed, 16 Apr 2025 05:57:40 +0000 Subject: [PATCH 10/10] unity check --- src/tir/ir/expr.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index defd74284f02..b52c85df3575 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -275,6 +275,8 @@ TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Sp return FloorDiv(a, b, span); }); +TVM_REGISTER_NODE_TYPE(FloorDivNode); + // FloorMod TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod);