From 4f7e1dbcba28a6dfa5fc78b21392f5ff624025bd Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 14 Jan 2020 11:29:40 -0800 Subject: [PATCH 1/3] [REFACTOR][IR] Unify IntImm and UIntImm This PR unifies UIntImm and IntImm to simplify the codebase. Unsigned integer constants will also be stored as IntImm. For uint constant that does not fit into int64(rare case), we introduced an intrinsic tvm_big_uint_imm to construct such intgers by its lower and higher 32bits. --- include/tvm/expr.h | 48 ++++++------------- include/tvm/expr_operator.h | 27 +++++++++-- include/tvm/ir.h | 10 ++++ include/tvm/ir/expr.h | 50 ++++++++++++++++++++ python/tvm/api.py | 3 ++ src/api/api_ir.cc | 1 - src/api/api_lang.cc | 3 ++ src/arithmetic/canonical_simplify.cc | 8 ++-- src/arithmetic/const_fold.h | 18 +++---- src/arithmetic/const_int_bound.cc | 2 +- src/arithmetic/int_set.cc | 2 +- src/arithmetic/modular_set.cc | 2 +- src/arithmetic/pattern_match.h | 6 +-- src/arithmetic/rewrite_simplify.cc | 26 +++++----- src/autotvm/touch_extractor.cc | 2 +- src/codegen/codegen_c.cc | 21 ++++++-- src/codegen/llvm/codegen_llvm.cc | 6 +++ src/codegen/llvm/codegen_x86_64.cc | 4 +- src/codegen/spirv/codegen_spirv.cc | 6 +++ src/codegen/spirv/ir_builder.cc | 4 +- src/codegen/spirv/ir_builder.h | 4 +- src/ir/expr.cc | 19 ++++++++ src/lang/expr.cc | 11 +---- src/lang/expr_operator.cc | 30 +++++++----- src/lang/ir.cc | 2 +- src/pass/arg_binder.cc | 10 ++-- src/pass/lower_intrin.cc | 2 +- src/pass/lower_tvm_builtin.cc | 4 +- src/pass/make_api.cc | 6 +-- src/pass/tensor_core.cc | 14 +++--- src/relay/backend/compile_engine.cc | 8 ++-- src/relay/ir/expr.cc | 2 +- src/relay/op/tensor/transform.cc | 2 +- tests/python/unittest/test_codegen_device.py | 27 +++++++++++ tests/python/unittest/test_codegen_llvm.py | 20 ++++++++ 35 files changed, 282 insertions(+), 128 deletions(-) diff --git a/include/tvm/expr.h b/include/tvm/expr.h index faae303d95dd..62806c667e61 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -115,56 +115,38 @@ class Var : public PrimExpr { using ContainerType = VarNode; }; -class Integer; -/*! \brief ExprNode: constant integer. */ -class IntImmNode : public PrimExprNode { - public: - /*! \brief the Internal value. */ - int64_t value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - TVM_DLL static Integer make(DataType t, int64_t value); - - static constexpr const char* _type_key = "IntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); -}; - /*! - * \brief Container of constant integer (IntImm). + * \brief Container of constant int that adds more constructors. * * This is used to store and automate type check * attributes that must be constant integer. + * + * \sa IntImm */ -class Integer : public PrimExpr { +class Integer : public IntImm { public: - Integer() : PrimExpr() {} + Integer() {} /*! * \brief constructor from node. */ - explicit Integer(ObjectPtr node) : PrimExpr(node) {} + explicit Integer(ObjectPtr node) : IntImm(node) {} /*! * \brief Construct integer from int value. */ - Integer(int value) : PrimExpr(value) {} // NOLINT(*) + Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*) + /*! + * \brief Construct integer from int imm. + * \param other The other value. + */ + Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*) /*! * \brief Assign an expression to integer. * \param other another expression. */ - Integer& operator=(const Integer& other) { - data_ = other.data_; + Integer& operator=(const IntImm& other) { + data_ = ObjectRef::GetDataPtr(other); return *this; } - /*! - * \brief Get pointer to the internal value. - * \return the content of the integer. - */ - const IntImmNode* operator->() const { - return static_cast(get()); - } /*! * \brief convert to int64_t */ @@ -173,8 +155,6 @@ class Integer : public PrimExpr { << " Trying to reference a null Integer"; return (*this)->value; } - /*! \brief type indicate the container type */ - using ContainerType = IntImmNode; }; /*! \brief range over one dimension */ diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 2d8f37855856..cbcb72a151e4 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -597,6 +597,15 @@ TVM_DLL PrimExpr nearbyint(PrimExpr x); */ TVM_DLL PrimExpr trunc(PrimExpr x); +/*! + * \brief Construct a big uint constant by its low 32 bits and high 32bits. + * \param dtype The final data type. + * \param low The lower 32 bits. + * \param high The higher 32 bits. + * \return The constructed expression. + */ +TVM_DLL PrimExpr BigUIntImm(DataType dtype, int64_t low, int64_t high); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline PrimExpr OpName(PrimExpr x) { \ @@ -675,15 +684,27 @@ inline bool is_no_op(const Stmt& stmt) { template inline PrimExpr MakeConstScalar(DataType t, ValueType value) { - if (t.is_int()) return ir::IntImmNode::make(t, static_cast(value)); - if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast(value)); + if (t.is_int()) return IntImm(t, static_cast(value)); + if (t.is_uint()) { + // Use IntImm if it is a small integer + uint64_t uval = static_cast(value); + if (uval <= static_cast(std::numeric_limits::max())) { + return IntImm(t, static_cast(value)); + } else { + uint64_t mask = (static_cast(1) << 32U) - 1U; + uint64_t low = uval & mask; + uint64_t high = uval >> 32U; + return BigUIntImm(t, static_cast(low), static_cast(high)); + } + } if (t.is_float()) return ir::FloatImmNode::make(t, static_cast(value)); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. // TODO(gus) when do we need to start worrying about doubles not being precise enough? - if (static_cast(t.code()) >= static_cast(kCustomBegin)) + if (static_cast(t.code()) >= static_cast(kCustomBegin)) { return ir::FloatImmNode::make(t, static_cast(value)); + } LOG(FATAL) << "cannot make const for type " << t; return PrimExpr(); } diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 84039485ae69..c637d055928c 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1422,6 +1422,16 @@ inline bool IsPragmaKey(const std::string& attr_key) { /*! \brief namespace of TVM Intrinsic functions */ namespace intrinsic { +/*! + * \brief See pesudo code + * + * Construct a big uint that may not be representable by int64 + * + * Expr tvm_big_uint_imm(uint32_t v0, uin32_t v1) { + * return (v1 << 32) | v0; + * } + */ +constexpr const char* tvm_big_uint_imm = "tvm_big_uint_imm"; /*! * \brief See pesudo code * diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 7b42678ee103..222f059ab76d 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -131,6 +131,56 @@ class PrimExpr : public BaseExpr { using ContainerType = PrimExprNode; }; +/*! + * \brief Constant integer literals in the program. + * \sa IntImm + */ +class IntImmNode : public PrimExprNode { + public: + /*! \brief the Internal value. */ + int64_t value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "IntImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); +}; + +/*! + * \brief Managed reference class to IntImmNode. + * + * \sa IntImmNode + */ +class IntImm : public PrimExpr { + public: + /*! + * \brief Constructor + */ + IntImm() {} + /*! + * \brief constructor from node. + */ + explicit IntImm(ObjectPtr node) : PrimExpr(node) {} + /*! + * \brief Constructor. + * \param dtype The data type of the value. + * \param value The internal value. + */ + TVM_DLL IntImm(DataType dtype, int64_t value); + /*! + * \brief Get pointer to the internal value. + * \return the content of the integer. + */ + const IntImmNode* operator->() const { + return static_cast(get()); + } + /*! \brief type indicate the container type */ + using ContainerType = IntImmNode; +}; + /*! * \brief Base node of all non-primitive expressions. * diff --git a/python/tvm/api.py b/python/tvm/api.py index 7395d3524709..9afa0cc0609e 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -92,6 +92,9 @@ def const(value, dtype=None): """ if dtype is None: dtype = _scalar_type_inference(value) + if dtype == "uint64" and value >= (1 << 63): + return _api_internal._BigUIntImm( + dtype, value & ((1 << 32) - 1), value >> 32) return _api_internal._const(value, dtype) diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index ca4823bc6b83..049b6ee38d48 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -130,7 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer") REGISTER_MAKE(Reduce); REGISTER_MAKE(AttrStmt); -REGISTER_MAKE(IntImm); REGISTER_MAKE(UIntImm); REGISTER_MAKE(FloatImm); REGISTER_MAKE(StringImm); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 6a8bc58ad7d0..6b0cfdd55bd6 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -53,6 +53,9 @@ TVM_REGISTER_GLOBAL("_const") } }); +TVM_REGISTER_GLOBAL("_BigUIntImm") +.set_body_typed(BigUIntImm); + TVM_REGISTER_GLOBAL("_str") .set_body_typed(ir::StringImmNode::make); diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 5f721d7a1f94..90c6e48ded1e 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -737,7 +737,7 @@ VisitExpr_(const DivNode* op) { // const folding PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x / c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -797,7 +797,7 @@ VisitExpr_(const FloorDivNode* op) { // const folding PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x / c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -905,7 +905,7 @@ VisitExpr_(const ModNode* op) { PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x % c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -975,7 +975,7 @@ VisitExpr_(const FloorModNode* op) { PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x % c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 55c156d898f9..2bee70ed557a 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -103,7 +103,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value); + if (pa && pb) return IntImm(rtype, pa->value + pb->value); if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value); @@ -117,7 +117,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value); + if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value); if (fb && fb->value == 0) return a; @@ -129,7 +129,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value); + if (pa && pb) return IntImm(rtype, pa->value * pb->value); if (pa) { if (pa->value == 1) return b; if (pa->value == 0) return a; @@ -159,7 +159,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { // due to division and mod can have different modes // NOTE: this will assumes truc div. CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImmNode::make(rtype, pa->value / pb->value); + return IntImm(rtype, pa->value / pb->value); } if (pa) { if (pa->value == 0) return a; @@ -185,7 +185,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImmNode::make(rtype, pa->value % pb->value); + return IntImm(rtype, pa->value % pb->value); } if (pa) { if (pa->value == 0) return a; @@ -204,7 +204,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) { CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImmNode::make(rtype, arith::floordiv(pa->value, pb->value)); + return IntImm(rtype, arith::floordiv(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; @@ -230,7 +230,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImmNode::make(rtype, arith::floormod(pa->value, pb->value)); + return IntImm(rtype, arith::floormod(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; @@ -247,7 +247,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value)); + if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; @@ -258,7 +258,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value)); + if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index a041e40abf46..3a85c39aa3f0 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -496,7 +496,7 @@ class ConstIntBoundAnalyzer::Impl : */ static std::vector DetectBoundInfo(const PrimExpr& cond) { PVar x, y; - PVar c; + PVar c; // NOTE: canonical form always use <= or < if ((c <= x).Match(cond)) { return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))}; diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index ceaa976469e8..9b1ab3d63907 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -476,7 +476,7 @@ class IntervalSetEvaluator : IntervalSet VisitExpr_(const RampNode* op) final { CHECK(eval_vec_); IntervalSet base = Eval(op->base); - PVar stride; + PVar stride; if (stride.Match(op->stride)) { DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 01dd2e8e499e..972c5148134f 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -109,7 +109,7 @@ class ModularSetAnalyzer::Impl : // Detect useful constraints and use them in the analysis scope. std::function EnterConstraint(const PrimExpr& constraint) { PVar var; - PVar coeff, base; + PVar coeff, base; // pattern match interesting constraints if ((truncmod(var, coeff) == base).Match(constraint) || (floormod(var, coeff) == base).Match(constraint)) { diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index 733dcf41ce94..a236e65a8312 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -45,7 +45,7 @@ * } * * tvm::Var tx, ty; - * arith::PVar c; + * arith::PVar c; * arith::PVar v; * // We can match integer and Var, both of which are * // special case container of Expr @@ -140,9 +140,9 @@ class PEqualChecker { }; template<> -class PEqualChecker { +class PEqualChecker { public: - bool operator()(const Integer& lhs, const Integer& rhs) const { + bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; } }; diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 94d951da51db..e6e1524604ce 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -124,7 +124,7 @@ VisitExpr_(const AddNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -239,7 +239,7 @@ VisitExpr_(const SubNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -438,7 +438,7 @@ VisitExpr_(const MulNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -477,7 +477,7 @@ VisitExpr_(const DivNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -700,7 +700,7 @@ VisitExpr_(const ModNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -789,7 +789,7 @@ VisitExpr_(const FloorDivNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -934,7 +934,7 @@ VisitExpr_(const FloorModNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -1004,7 +1004,7 @@ VisitExpr_(const MinNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1189,7 +1189,7 @@ VisitExpr_(const MaxNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1362,7 +1362,7 @@ VisitExpr_(const EQNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1; + PVar c1; PVar lanes; // vector rule @@ -1416,7 +1416,7 @@ VisitExpr_(const LTNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1597,7 +1597,7 @@ VisitExpr_(const AndNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; if (op->dtype.lanes() != 1) { @@ -1646,7 +1646,7 @@ VisitExpr_(const OrNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; if (op->dtype.lanes() != 1) { diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index cf138edd494e..55ed36ca9352 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -256,7 +256,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > Array attr{std::string("_attr_"), FloatImmNode::make(DataType::Float(32), trans(fea.length)), - IntImmNode::make(DataType::Int(32), fea.nest_level), + IntImm(DataType::Int(32), fea.nest_level), FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)), FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)), }; diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 777ad6203008..eae15248751b 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -372,19 +372,24 @@ inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // } } -inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - if (op->dtype == DataType::UInt(32)) { + +inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) { // NOLINT(*) + if (dtype == DataType::UInt(32)) { std::ostringstream temp; - temp << op->value << "U"; + temp << val << "U"; p->MarkConst(temp.str()); os << temp.str(); } else { os << "("; - p->PrintType(op->dtype, os); - os << ")" << op->value; + p->PrintType(dtype, os); + os << ")" << val; } } +inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintUIntConst(op->dtype, op->value, os, p); +} + inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) switch (op->dtype.bits()) { case 64: case 32: { @@ -528,6 +533,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) os << ")"; } else if (op->is_intrinsic(CallNode::bitwise_and)) { PrintBinaryIntrinsic(op, " & ", os, this); + } else if (op->is_intrinsic(intrinsic::tvm_big_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + PrintUIntConst(op->dtype, val, os, this); } else if (op->is_intrinsic(CallNode::bitwise_xor)) { PrintBinaryIntrinsic(op, " ^ ", os, this); } else if (op->is_intrinsic(CallNode::bitwise_or)) { diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index c04a023aefad..20edd0a901a7 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -722,6 +722,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return llvm::Constant::getNullValue(t_void_p_); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { return builder_->CreateIsNull(MakeValue(op->args[0])); + } else if (op->is_intrinsic(intrinsic::tvm_big_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + return llvm::ConstantInt::get(LLVMType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc index 03656cc70a46..11bda70fb8cf 100644 --- a/src/codegen/llvm/codegen_x86_64.cc +++ b/src/codegen/llvm/codegen_x86_64.cc @@ -96,8 +96,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { MakeValue( ir::BroadcastNode::make( ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())), - /*mask=*/MakeValue(ir::IntImmNode::make(DataType::Int(16), -1)), - /*rounding-mode=*/MakeValue(ir::IntImmNode::make(DataType::Int(32), 4)), + /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), + /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), }); } diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index a749424892e2..ac7423e8ad87 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -285,6 +285,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else if (op->is_intrinsic(CallNode::reinterpret)) { return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), MakeValue(op->args[0])); + } else if (op->is_intrinsic(intrinsic::tvm_big_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + return builder_->UIntImm(builder_->GetSType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return this->CreateStorageSync(op); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 6f8d96e148c1..bf43f11cce02 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -342,9 +342,9 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { if (dtype.type == DataType::UInt(1)) { // bool types. if (*pvalue) { - ib_.Begin(spv::OpConstantTrue).AddSeq(ret); + ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret); } else { - ib_.Begin(spv::OpConstantFalse).AddSeq(ret); + ib_.Begin(spv::OpConstantFalse).AddSeq(dtype, ret); } } else { // Integral/floating-point types. diff --git a/src/codegen/spirv/ir_builder.h b/src/codegen/spirv/ir_builder.h index 3843cbb3c6a9..5d25e8634e84 100644 --- a/src/codegen/spirv/ir_builder.h +++ b/src/codegen/spirv/ir_builder.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/ir/expr.cc b/src/ir/expr.cc index f698a5d1802e..6d89967416b0 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -26,6 +26,25 @@ namespace tvm { +IntImm::IntImm(DataType dtype, int64_t value) { + CHECK(dtype.is_scalar()) + << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_int() || dtype.is_uint()) + << "ValueError: IntImm can only take scalar."; + if (dtype.is_uint()) { + CHECK_GE(value, 0U); + } + ObjectPtr node = make_object(); + node->dtype = dtype; + node->value = value; + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("make.IntImm") +.set_body_typed([](DataType dtype, int64_t value) { + return IntImm(dtype, value); +}); + GlobalVar::GlobalVar(std::string name_hint) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); diff --git a/src/lang/expr.cc b/src/lang/expr.cc index a7289369bcd4..55dfb89342a8 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -30,7 +30,7 @@ namespace tvm { PrimExpr::PrimExpr(int32_t value) - : PrimExpr(IntImmNode::make(DataType::Int(32), value)) {} + : PrimExpr(IntImm(DataType::Int(32), value)) {} PrimExpr::PrimExpr(float value) : PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {} @@ -54,15 +54,6 @@ Range::Range(PrimExpr begin, PrimExpr end) is_zero(begin) ? end : (end - begin))) { } -Integer IntImmNode::make(DataType t, int64_t value) { - CHECK(t.is_int() && t.is_scalar()) - << "ValueError: IntImm can only take scalar."; - ObjectPtr node = make_object(); - node->dtype = t; - node->value = value; - return Integer(node); -} - Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) { return Range(make_object(min, extent)); } diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index d3875e28c887..6c7c54726eb9 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -35,6 +35,14 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { return ir::CastNode::make(t, value); } +PrimExpr BigUIntImm(DataType t, int64_t low, int64_t high) { + return ir::CallNode::make( + t, ir::intrinsic::tvm_big_uint_imm, + {make_const(DataType::UInt(32), low), + make_const(DataType::UInt(32), high)}, + ir::CallNode::PureIntrinsic); +} + // The public function with a quick checking path. void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) if (lhs.dtype() == rhs.dtype()) return; @@ -85,11 +93,11 @@ PrimExpr max_value(const DataType& dtype) { CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { - return IntImmNode::make(dtype, std::numeric_limits::max()); + return IntImm(dtype, std::numeric_limits::max()); } else if (dtype.bits() < 64) { int64_t val = 1; val = (val << (dtype.bits() - 1)) - 1; - return IntImmNode::make(dtype, val); + return IntImm(dtype, val); } } else if (dtype.is_uint()) { if (dtype.bits() == 64) { @@ -117,11 +125,11 @@ PrimExpr min_value(const DataType& dtype) { CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { - return IntImmNode::make(dtype, std::numeric_limits::lowest()); + return IntImm(dtype, std::numeric_limits::lowest()); } else if (dtype.bits() < 64) { int64_t val = 1; val = -(val << (dtype.bits() - 1)); - return IntImmNode::make(dtype, val); + return IntImm(dtype, val); } } else if (dtype.is_uint()) { return UIntImmNode::make(dtype, 0); @@ -219,7 +227,7 @@ PrimExpr operator-(PrimExpr a) { using ir::FloatImmNode; const IntImmNode* pa = a.as(); const FloatImmNode* fa = a.as(); - if (pa) return ir::IntImmNode::make(a.dtype(), -pa->value); + if (pa) return IntImm(a.dtype(), -pa->value); if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value); return make_zero(a.dtype()) - a; } @@ -424,7 +432,7 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value >> pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value >> pb->value)); if (pb) { if (pb->value == 0) return a; } @@ -437,7 +445,7 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value << pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value << pb->value)); if (pb) { if (pb->value == 0) return a; } @@ -450,7 +458,7 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value & pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); }); return ir::CallNode::make( a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic); @@ -460,7 +468,7 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value | pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); }); return ir::CallNode::make( a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic); @@ -470,7 +478,7 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, (pa->value ^ pb->value)); + if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); }); return ir::CallNode::make( a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic); @@ -494,7 +502,7 @@ PrimExpr abs(PrimExpr x) { using ir::IntImmNode; const IntImmNode* px = x.as(); if (px) { - return ir::IntImmNode::make(x.dtype(), std::abs(px->value)); + return IntImm(x.dtype(), std::abs(px->value)); } return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x); } else if (x.dtype().is_float()) { diff --git a/src/lang/ir.cc b/src/lang/ir.cc index ad7f260226bd..5a24e965e780 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -248,7 +248,7 @@ PrimExpr ShuffleNode::make_concat(Array vectors) { int index = 0; for (const PrimExpr& e : vectors) { for (int i = 0; i < e.dtype().lanes(); ++i) { - indices.push_back(IntImmNode::make(DataType::Int(32), index++)); + indices.push_back(IntImm(DataType::Int(32), index++)); } } return make(vectors, indices); diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 2c04de3710fa..612a56664c8c 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -193,7 +193,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, // mark alignment of external bufs init_nest_.emplace_back(AttrStmtNode::make( vptr, ir::attr::storage_alignment, - IntImmNode::make(DataType::Int(32), buffer->data_alignment), nop)); + IntImm(DataType::Int(32), buffer->data_alignment), nop)); } Var v_shape(arg_name + ".shape", DataType::Handle()); @@ -206,7 +206,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Bind_(buffer->shape[k], cast(buffer->shape[k].dtype(), LoadNode::make(tvm_shape_type, v_shape, - IntImmNode::make(DataType::Int(32), k), const_true(1))), + IntImm(DataType::Int(32), k), const_true(1))), field_name.str(), true); } // strides field @@ -228,7 +228,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, PrimExpr svalue = cast( stype, LoadNode::make(tvm_shape_type, v_strides, - IntImmNode::make(DataType::Int(32), k), const_true(1))); + IntImm(DataType::Int(32), k), const_true(1))); conds.push_back(expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } @@ -251,7 +251,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, field_name << v_strides->name_hint << '[' << k << ']'; PrimExpr value = cast(buffer->shape[k].dtype(), LoadNode::make(tvm_shape_type, v_strides, - IntImmNode::make(DataType::Int(32), k), const_true(1))); + IntImm(DataType::Int(32), k), const_true(1))); value = tvm::if_then_else(is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); @@ -270,7 +270,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Bind_(buffer->strides[k], cast(buffer->shape[k].dtype(), LoadNode::make(tvm_shape_type, v_strides, - IntImmNode::make(DataType::Int(32), k), const_true(1))), + IntImm(DataType::Int(32), k), const_true(1))), field_name.str(), true); } } diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index ed8be8bb39fc..5684f4ef785f 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -173,7 +173,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const MaxNode* op) final { using namespace arith; PVar x, y; - PVar c; + PVar c; auto e = GetRef(op); if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index 8e7f1d86da74..01a97b7878be 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -129,8 +129,8 @@ class BuiltinLower : public StmtExprMutator { {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), total_bytes), - IntImmNode::make(DataType::Int(32), op->dtype.code()), - IntImmNode::make(DataType::Int(32), op->dtype.bits())}, + IntImm(DataType::Int(32), op->dtype.code()), + IntImm(DataType::Int(32), op->dtype.bits())}, CallNode::Extern), body); diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index d5c73a2e8a75..5df36d0b2423 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -69,8 +69,8 @@ LoweredFunc MakeAPI(Stmt body, // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { Array call_args{v_packed_args, - IntImmNode::make(DataType::Int(32), i), - IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)}; + IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); PrimExpr res = CallNode::make( @@ -117,7 +117,7 @@ LoweredFunc MakeAPI(Stmt body, seq_init.emplace_back(LetStmtNode::make( tcode, LoadNode::make( DataType::Int(32), v_packed_arg_type_ids, - IntImmNode::make(DataType::Int(32), i), const_true(1)), + IntImm(DataType::Int(32), i), const_true(1)), nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc index bb57fe8c37d3..956f27c9319d 100644 --- a/src/pass/tensor_core.cc +++ b/src/pass/tensor_core.cc @@ -462,7 +462,7 @@ class BufferAnalyser : public StmtExprVisitor { strides = bi.strides; } else { for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, bi.shape[j]); } @@ -575,7 +575,7 @@ class BufferAnalyser : public StmtExprVisitor { strides = bi.strides; } else { for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, bi.shape[j]); } @@ -765,7 +765,7 @@ class ThreadIdxMutator : public StmtExprMutator { op = expr.as(); if (op != nullptr) { if (op->name_hint == "threadIdx.x") { - PrimExpr zero = IntImmNode::make(DataType::Int(32), 0); + PrimExpr zero = IntImm(DataType::Int(32), 0); return zero; } if (op->name_hint == "threadIdx.y") { @@ -934,7 +934,7 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr stride = strides[strides.size()-2]; // thread index unification inside a warp - PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_); + PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); PrimExpr src = CallNode::make(value->dtype, @@ -984,7 +984,7 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr dst = it3->second; // thread index unification inside a warp - PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_); + PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); dst = CallNode::make(DataType::Handle(), @@ -1089,7 +1089,7 @@ class TensorCoreIRMutator : public StmtExprMutator { Array strides; for (size_t i = 1; i < shape.size(); ++i) { - PrimExpr stride = IntImmNode::make(DataType::Int(32), 1); + PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = shape.size() - 1; j >= i; --j) { stride = MulNode::make(stride, shape[j]); } @@ -1097,7 +1097,7 @@ class TensorCoreIRMutator : public StmtExprMutator { } strides.push_back(make_const(DataType::Int(32), 1)); - PrimExpr elem_offset = IntImmNode::make(DataType::Int(32), 0); + PrimExpr elem_offset = IntImm(DataType::Int(32), 0); CHECK_EQ(call->args.size(), min_bound.size()); for (size_t i = 0; i < min_bound.size(); i++) { elem_offset = AddNode::make( diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index d4a7cb1f2ad1..e8af5d2226db 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -88,7 +88,7 @@ Array GetShape(const Array& shape) { if (pval != nullptr) { CHECK_LE(pval[0], std::numeric_limits::max()); CHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(ir::IntImmNode::make(DataType::Int(32), *pval)); + res.push_back(IntImm(DataType::Int(32), *pval)); } else if (val->IsInstance()) { res.push_back(val.as()->ToVar()); } else { @@ -395,7 +395,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { // set inputs for (auto param : prim_func->params) { int state = param_states_[param]; - cache_node->shape_func_param_states.push_back(IntImmNode::make(DataType::Int(32), state)); + cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); if (state & kNeedInputData) { for (auto t : param_data_[param]) { cache_node->inputs.push_back(t); @@ -528,7 +528,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { auto ret_type = call_node->checked_type(); Array out_ndims; if (const auto* ttype = ret_type.as()) { - out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size())); + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); } else { auto rtype = ret_type.as(); // TODO(@icemelon): Allow recursive tuple @@ -536,7 +536,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { for (size_t i = 0; i < rtype->fields.size(); ++i) { auto ttype = rtype->fields[i].as(); CHECK(ttype); - out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size())); + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); } } // Call shape function diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 239a33ea642c..e2a48ecceda6 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -56,7 +56,7 @@ TensorType ConstantNode::tensor_type() const { CHECK_LE(data->shape[i], std::numeric_limits::max()); CHECK_GE(data->shape[i], std::numeric_limits::min()); shape.push_back( - tvm::ir::IntImmNode::make(DataType::Int(32), data->shape[i])); + tvm::IntImm(DataType::Int(32), data->shape[i])); } return TensorTypeNode::make(shape, dtype); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index aa643c4d8bc6..162b612c0530 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -852,7 +852,7 @@ bool ArgWhereRel(const Array& types, const auto& input_rank = input_shape.size(); std::vector result_shape; result_shape.push_back(Any::make()); - result_shape.push_back(IntImmNode::make(DataType::Int(32), input_rank)); + result_shape.push_back(IntImm(DataType::Int(32), input_rank)); reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32))); return true; } diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 45ecf9539337..592b073767a8 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -18,6 +18,32 @@ from tvm.contrib import util import numpy as np +def test_big_uint_imm(): + value = (1 << 63) + 123 + other = tvm.const(3, "uint64") + n = 12 + num_thread = 2 + + A = tvm.compute((n,), lambda *i: tvm.const(value, "uint64") + other, name='A') + s = tvm.create_schedule(A.op) + xo, xi = s[A].split(A.op.axis[0], factor=num_thread) + s[A].bind(xi, tvm.thread_axis("threadIdx.x")) + s[A].bind(xo, tvm.thread_axis("blockIdx.x")) + + def check_target(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + return + f = tvm.build(s, [A], device) + # launch the kernel. + a = tvm.nd.empty((n, ), dtype=A.dtype, ctx=ctx) + f(a) + assert a.asnumpy()[0] == value + 3 + + check_target("cuda") + check_target("vulkan") + + def test_add_pipeline(): n = tvm.var('n') A = tvm.placeholder((n,), name='A') @@ -112,4 +138,5 @@ def check_module_save(device, host="stackvm"): if __name__ == "__main__": + test_big_uint_imm() test_add_pipeline() diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index 0e595cd79c97..f21bc33dadd5 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -88,6 +88,25 @@ def test_llvm_lookup_intrin(): fcode = tvm.build(func, None, "llvm") +def test_llvm_big_uintimm(): + value = (1 << 63) + 123 + other = tvm.const(3, "uint64") + A = tvm.compute((), lambda : tvm.const(value, "uint64") + other, name='A') + s = tvm.create_schedule(A.op) + + def check_llvm(): + if not tvm.module.enabled("llvm"): + return + f = tvm.build(s, [A], "llvm") + ctx = tvm.cpu(0) + # launch the kernel. + a = tvm.nd.empty((), dtype=A.dtype, ctx=ctx) + f(a) + assert a.asnumpy() == value + 3 + + check_llvm() + + def test_llvm_add_pipeline(): nn = 1024 n = tvm.convert(nn) @@ -645,6 +664,7 @@ def vectorizer(op): tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) if __name__ == "__main__": + test_llvm_big_uintimm() test_llvm_import() test_alignment() test_rank_zero() From 501a099da2fbf3f4e028de6933b708a2932f03ec Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 14 Jan 2020 13:45:28 -0800 Subject: [PATCH 2/3] [REFACTOR][IR] Remove UIntImm to use IntImm --- include/tvm/attrs.h | 4 -- include/tvm/expr_operator.h | 26 ++--------- include/tvm/ir.h | 17 -------- include/tvm/ir_functor_ext.h | 4 -- python/tvm/autotvm/task/task.py | 4 +- python/tvm/autotvm/util.py | 8 ++-- python/tvm/expr.py | 17 -------- python/tvm/hybrid/calls.py | 2 +- python/tvm/hybrid/parser.py | 4 +- python/tvm/hybrid/util.py | 2 +- python/tvm/relay/frontend/tensorflow.py | 2 +- src/api/api_ir.cc | 1 - src/arithmetic/analyzer.cc | 6 +-- src/arithmetic/const_fold.h | 43 ++++++++----------- src/arithmetic/const_int_bound.cc | 8 ---- src/arithmetic/int_set.cc | 4 -- src/arithmetic/modular_set.cc | 8 ---- src/codegen/codegen_c.cc | 8 +--- src/codegen/codegen_c.h | 1 - src/codegen/codegen_opengl.cc | 5 --- src/codegen/codegen_opengl.h | 1 - src/codegen/llvm/codegen_arm.cc | 22 +++++----- src/codegen/llvm/codegen_llvm.cc | 12 ++---- src/codegen/llvm/codegen_llvm.h | 1 - src/codegen/llvm/intrin_rule_llvm.h | 8 ++-- src/codegen/spirv/codegen_spirv.cc | 7 +-- src/codegen/spirv/codegen_spirv.h | 1 - src/codegen/spirv/intrin_rule_spirv.cc | 2 +- src/codegen/stackvm/codegen_stackvm.cc | 6 --- src/codegen/stackvm/codegen_stackvm.h | 1 - src/contrib/hybrid/codegen_hybrid.cc | 5 +-- src/contrib/hybrid/codegen_hybrid.h | 1 - src/lang/attr_functor.h | 4 -- src/lang/attrs.cc | 11 ----- src/lang/expr_operator.cc | 25 ++--------- src/lang/ir.cc | 14 ------ src/pass/arg_binder.cc | 6 +-- src/pass/ir_deep_compare.cc | 4 -- src/pass/ir_functor.cc | 2 - src/pass/lift_attr_scope.cc | 3 -- src/pass/lower_thread_allreduce.cc | 2 +- src/pass/rewrite_unsafe_select.cc | 1 - src/pass/unroll_loop.cc | 4 -- src/relay/ir/pretty_printer.cc | 4 -- src/relay/pass/type_solver.cc | 2 +- src/relay/qnn/util.h | 12 +----- tests/cpp/pattern_match_test.cc | 4 +- tests/python/unittest/test_hybrid_script.py | 2 +- .../python/unittest/test_lang_constructor.py | 7 +-- tests/python/unittest/test_lang_operator.py | 2 +- topi/include/topi/detail/constant_utils.h | 10 ++--- topi/python/topi/util.py | 12 +++--- 52 files changed, 83 insertions(+), 289 deletions(-) diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index ab9a711d28d8..9d9f98e79695 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -490,8 +490,6 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { CHECK(expr.defined()); if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); } @@ -523,8 +521,6 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { *ptr = static_cast(op->value); } else if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); } diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index cbcb72a151e4..d47759222112 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -30,6 +30,7 @@ #include #include +#include #include "expr.h" #include "ir.h" @@ -82,21 +83,6 @@ inline const int64_t* as_const_int(const PrimExpr& x) { } } -/*! - * \brief Get x as constant uint expression. - * \param x The expression - * \return the address to the int expression, - * return nullptr, if x is not UIntImm. - */ -inline const uint64_t* as_const_uint(const PrimExpr& x) { - if (!x.defined()) return nullptr; - if (const ir::UIntImmNode* op = x.as()) { - return &(op->value); - } else { - return nullptr; - } -} - /*! * \brief Check whether x is a constant integer expression. * \param x The input argument @@ -626,11 +612,11 @@ TVM_DECLARE_INTRIN_UNARY(atan); // Implementation details after this inline bool is_const(const PrimExpr& x) { - if (x.as() || x.as()) { + if (x.as()) { return true; } else if (const auto* op = x.as()) { const PrimExpr& val = op->value; - if (val.as() || val.as()) { + if (val.as()) { return true; } } @@ -640,8 +626,6 @@ inline bool is_const(const PrimExpr& x) { inline bool is_positive_const(const PrimExpr& a) { if (const ir::IntImmNode* op = a.as()) { return op->value > 0; - } else if (const ir::UIntImmNode* op = a.as()) { - return op->value > 0; } else { return false; } @@ -658,14 +642,10 @@ inline bool is_negative_const(const PrimExpr& a) { inline bool is_const_int(const PrimExpr& x, int64_t value) { if (const auto* op = x.as()) { return op->value == value; - } else if (const auto* op = x.as()) { - return op->value == static_cast(value); } else if (const auto* op = x.as()) { const PrimExpr& val = op->value; if (const auto* opv = val.as()) { return opv->value == value; - } else if (const auto* opv = val.as()) { - return opv->value == static_cast(value); } } return false; diff --git a/include/tvm/ir.h b/include/tvm/ir.h index c637d055928c..20ebd92fc423 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -39,23 +39,6 @@ namespace ir { using IntImmNode = tvm::IntImmNode; using VarNode = tvm::VarNode; -/*! \brief constant unsigned integer. */ -class UIntImmNode : public PrimExprNode { - public: - /*! \brief The constant value content. */ - uint64_t value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - TVM_DLL static PrimExpr make(DataType t, uint64_t value); - - static constexpr const char* _type_key = "UIntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode); -}; - /*! \brief Floating point constants. */ class FloatImmNode : public PrimExprNode { public: diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 7d57564fd3df..37a1fe4bffb2 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -161,7 +161,6 @@ class ExprFunctor { virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args ...) { @@ -203,7 +202,6 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode); IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode); IR_EXPR_FUNCTOR_DISPATCH(IntImmNode); - IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode); IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); return vtable; @@ -327,7 +325,6 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const BroadcastNode* op) override; void VisitExpr_(const ShuffleNode* op) override; void VisitExpr_(const IntImmNode* op) override; - void VisitExpr_(const UIntImmNode* op) override; void VisitExpr_(const FloatImmNode* op) override; void VisitExpr_(const StringImmNode* op) override; }; @@ -372,7 +369,6 @@ class TVM_DLL ExprMutator : PrimExpr VisitExpr_(const BroadcastNode* op) override; PrimExpr VisitExpr_(const ShuffleNode* op) override; PrimExpr VisitExpr_(const IntImmNode* op) override; - PrimExpr VisitExpr_(const UIntImmNode* op) override; PrimExpr VisitExpr_(const FloatImmNode* op) override; PrimExpr VisitExpr_(const StringImmNode* op) override; }; diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 7f36914eb0a6..5067277d32a8 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -221,7 +221,7 @@ def args_to_workload(x, topi_compute_func=None): workload = tuple([args_to_workload(a) for a in x]) elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)): workload = x - elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)): + elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): workload = x.value elif x is None: workload = 0 @@ -344,7 +344,7 @@ def _count_flop(exp): if len(source) != 1: raise FlopCalculationError("Found multiple output in the source of reduce op") return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0])) - if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)): + if isinstance(exp, (expr.FloatImm, expr.IntImm)): return 0 if isinstance(exp, expr.Cast): return _count_flop(exp.value) diff --git a/python/tvm/autotvm/util.py b/python/tvm/autotvm/util.py index 3026914aed20..54001d3338ad 100644 --- a/python/tvm/autotvm/util.py +++ b/python/tvm/autotvm/util.py @@ -155,9 +155,9 @@ def get_const_int(exp): """ if isinstance(exp, int): return exp - if not isinstance(exp, (expr.IntImm, expr.UIntImm)): + if not isinstance(exp, (expr.IntImm,)): exp = ir_pass.Simplify(exp) - if not isinstance(exp, (expr.IntImm, expr.UIntImm)): + if not isinstance(exp, (expr.IntImm,)): raise ValueError("Expect value to be constant int") return exp.value @@ -179,9 +179,9 @@ def get_const_tuple(in_tuple): for elem in in_tuple: if isinstance(elem, expr.Var): ret.append(elem) - elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)): + elif not isinstance(elem, (expr.IntImm, int)): elem = ir_pass.Simplify(elem) - if not isinstance(elem, (expr.IntImm, expr.UIntImm)): + if not isinstance(elem, (expr.IntImm)): ret.append(elem) else: ret.append(get_const_int(elem)) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 71c0aecd1f6a..2fd7b78d9d66 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -341,23 +341,6 @@ def __int__(self): return self.value -@register_object -class UIntImm(ConstExpr): - """UInt constant. - - Parameters - ---------- - dtype : str - The data type - - value : int - The constant value. - """ - def __init__(self, dtype, value): - self.__init_handle_by_constructor__( - _make.UIntImm, dtype, value) - - @register_object class StringImm(ConstExpr): """String constant. diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 1d5612e67e80..7038f6144db3 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -156,6 +156,6 @@ def max_num_threads(func_id, args): if args.__len__() == 0: res = _tgt.current_target().max_num_threads else: - _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint") + _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint") res = _tgt.current_target(args[0].value).max_num_threads return _api.convert(res) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 06bcbcabe0c3..57d636328816 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -386,7 +386,7 @@ def visit_Subscript(self, node): if isinstance(i, numbers.Integral): arr = arr[i] else: - _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \ + _internal_assert(isinstance(i, (_expr.IntImm,)), \ "All indices are supposed to be constants") arr = arr[i.value] return arr @@ -413,7 +413,7 @@ def visit_If(self, node): cond = _ir_pass.CanonicalSimplify(self.visit(node.test)) # Return no IfThenElse if proven - if isinstance(cond, _expr.UIntImm): + if isinstance(cond, _expr.IntImm): if cond.value: return visit_list_to_block(self.visit, node.body) if node.orelse: diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 0dd1fa141329..a08a380dd767 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -33,7 +33,7 @@ #pylint: disable=invalid-name np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) -halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) +halide_imm_types = (_expr.IntImm, _expr.FloatImm) def _internal_assert(cond, err): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7e22d72131ac..e7f4682e7eb2 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -931,7 +931,7 @@ def _shape(): def _impl(inputs, attr, params): is_symbolic_shape = False for axis in attr['_input_shapes'][inputs[0]]: - if not isinstance(axis, (int, tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(axis, (int, tvm.expr.IntImm)): is_symbolic_shape = True break diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 049b6ee38d48..30ca51592c8f 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -130,7 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer") REGISTER_MAKE(Reduce); REGISTER_MAKE(AttrStmt); -REGISTER_MAKE(UIntImm); REGISTER_MAKE(FloatImm); REGISTER_MAKE(StringImm); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 7a3baa678352..e03e5e2387bf 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { } bool Analyzer::CanProve(const PrimExpr& expr) { - if (const auto* ptr = expr.as()) { + if (const auto* ptr = expr.as()) { return ptr->value != 0; } auto res = this->rewrite_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } res = this->canonical_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } return false; diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 2bee70ed557a..3b803ecd84a2 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -76,8 +76,6 @@ inline bool IsIndexType(const DataType& type) { #define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using ir::IntImmNode; \ - using ir::UIntImmNode; \ using ir::FloatImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ @@ -87,8 +85,6 @@ inline bool IsIndexType(const DataType& type) { #define TVM_INDEX_CONST_PROPAGATION(BODY) \ - using ir::IntImmNode; \ - using ir::UIntImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ const DataType& ta = a.dtype(); \ @@ -268,8 +264,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); }); return PrimExpr(); } @@ -277,8 +273,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); }); return PrimExpr(); } @@ -286,8 +282,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); }); return PrimExpr(); } @@ -295,8 +291,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); }); return PrimExpr(); } @@ -304,8 +300,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); }); return PrimExpr(); } @@ -313,17 +309,16 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); }); return PrimExpr(); } template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); - const UIntImmNode* pb = b.as(); + const IntImmNode* pa = a.as(); + const IntImmNode* pb = b.as(); if (pa && pa->value) return b; if (pa && !pa->value) return a; if (pb && pb->value) return a; @@ -333,9 +328,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); - const UIntImmNode* pb = b.as(); + const IntImmNode* pa = a.as(); + const IntImmNode* pb = b.as(); if (pa && pa->value) return a; if (pa && !pa->value) return b; if (pb && pb->value) return b; @@ -345,10 +339,9 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); + const IntImmNode* pa = a.as(); if (pa) { - return UIntImmNode::make(DataType::UInt(1), !(pa->value)); + return IntImm(DataType::UInt(1), !(pa->value)); } return PrimExpr(); } diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 3a85c39aa3f0..25d88d3429b6 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -150,14 +150,6 @@ class ConstIntBoundAnalyzer::Impl : return MakeBound(op->value, op->value); } - Entry VisitExpr_(const UIntImmNode* op) final { - if (op->value <= static_cast(kPosInf)) { - return MakeBound(op->value, op->value); - } else { - return Everything(op->dtype); - } - } - Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 9b1ab3d63907..37d5e9eb5e57 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -384,10 +384,6 @@ class IntervalSetEvaluator : return IntervalSet::SinglePoint(GetRef(op)); } - IntervalSet VisitExpr_(const UIntImmNode* op) final { - return IntervalSet::SinglePoint(GetRef(op)); - } - IntervalSet VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto it = dom_map_.find(var); diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 972c5148134f..c81842035c9f 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -132,14 +132,6 @@ class ModularSetAnalyzer::Impl : return Entry(0, op->value); } - Entry VisitExpr_(const UIntImmNode* op) final { - if (op->value < std::numeric_limits::max()) { - return Entry(0, static_cast(op->value)); - } else { - return Everything(); - } - } - Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index eae15248751b..906631368f74 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -386,10 +386,6 @@ inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeG } } -inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintUIntConst(op->dtype, op->value, os, p); -} - inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) switch (op->dtype.bits()) { case 64: case 32: { @@ -413,9 +409,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*) - PrintConst(op, os, this); -} + void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index cb092c566322..7e5dd4269c94 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -128,7 +128,6 @@ class CodeGenC : void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index 7967c1847ac2..cea276d5cb1a 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -247,11 +247,6 @@ void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) { CodeGenC::VisitExpr_(op, os); } -void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) { - CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints."; - CodeGenC::VisitExpr_(op, os); -} - void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats."; CodeGenC::VisitExpr_(op, os); diff --git a/src/codegen/codegen_opengl.h b/src/codegen/codegen_opengl.h index cd1ec83360c6..19ca2ee12c6c 100644 --- a/src/codegen/codegen_opengl.h +++ b/src/codegen/codegen_opengl.h @@ -50,7 +50,6 @@ class CodeGenOpenGL final : public CodeGenC { // Codegen for immediate values void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc index 6879fd5f8542..44862cf7a97c 100644 --- a/src/codegen/llvm/codegen_arm.cc +++ b/src/codegen/llvm/codegen_arm.cc @@ -48,7 +48,7 @@ class CodeGenARM final : public CodeGenCPU { llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { llvm::Intrinsic::ID id = static_cast( - op->args[0].as()->value); + Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); return CodeGenCPU::CreateIntrinsic(e.as()); @@ -68,8 +68,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { if (!call->dtype.is_vector() || call->dtype.bits() == 8 || (total_size != 128 && total_size != 64)) { Array vcnt_args; - vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); - vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); + vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } @@ -93,16 +93,16 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { const CallNode* c0 = input8.as(); CHECK(c0 != nullptr); Array vcnt8_args; - vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); - vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); + vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); PrimExpr vcnt8 = ir::CallNode::make( uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; - vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); PrimExpr vcnt16 = ir::CallNode::make( uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); @@ -112,8 +112,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // Accumulation 16->32bit Array vcnt32_args; - vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); PrimExpr vcnt32 = ir::CallNode::make( uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); @@ -123,8 +123,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // Accumulation 32->64bit Array vcnt64_args; - vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); return ir::CallNode::make( call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 20edd0a901a7..75982cc21848 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -662,15 +662,13 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast( - op->args[0].as()->value); - const uint64_t *num_signature = as_const_uint(op->args[1]); - CHECK(num_signature) << "The second argument should be a uint represents number of arguments, " - << "but " << op->args[1] << " got!\n"; + Downcast(op->args[0])->value); + int64_t num_signature = Downcast(op->args[1])->value; std::vector arg_value; std::vector sig_type; for (size_t i = 2; i < op->args.size(); ++i) { arg_value.push_back(MakeValue(op->args[i])); - if (i - 2 < *num_signature) { + if (i - 2 < static_cast(num_signature)) { sig_type.push_back(arg_value.back()->getType()); } } @@ -810,10 +808,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value); } -llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImmNode* op) { - return llvm::ConstantInt::get(LLVMType(op->dtype), op->value); -} - llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { return llvm::ConstantFP::get(LLVMType(op->dtype), op->value); } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 34c3ee723e18..b269f2423fc8 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -106,7 +106,6 @@ class CodeGenLLVM : llvm::Value* VisitExpr_(const VarNode* op) override; llvm::Value* VisitExpr_(const CastNode* op) override; llvm::Value* VisitExpr_(const IntImmNode* op) override; - llvm::Value* VisitExpr_(const UIntImmNode* op) override; llvm::Value* VisitExpr_(const FloatImmNode* op) override; llvm::Value* VisitExpr_(const StringImmNode* op) override; llvm::Value* VisitExpr_(const AddNode* op) override; diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h index b3ab557ee215..1f839f362f40 100644 --- a/src/codegen/llvm/intrin_rule_llvm.h +++ b/src/codegen/llvm/intrin_rule_llvm.h @@ -43,8 +43,8 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); + cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), num_signature)); for (PrimExpr arg : call->args) { cargs.push_back(arg); @@ -60,8 +60,8 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); + cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), num_signature)); for (PrimExpr arg : call->args) { cargs.push_back(arg); } diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index ac7423e8ad87..8016444dad50 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -136,10 +136,6 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) { return builder_->IntImm(builder_->GetSType(op->dtype), op->value); } -spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImmNode* op) { - return builder_->UIntImm(builder_->GetSType(op->dtype), op->value); -} - spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) { return builder_->FloatImm(builder_->GetSType(op->dtype), op->value); } @@ -242,7 +238,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { if (op->is_intrinsic("spirv_glsl450")) { CHECK_GE(op->args.size(), 2U); - uint32_t inst_id = op->args[0].as()->value; + uint32_t inst_id = static_cast( + op->args[0].as()->value); std::vector values; for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 3804bda0f2e0..5aa7f9c49910 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -65,7 +65,6 @@ class CodeGenSPIRV: spirv::Value VisitExpr_(const VarNode* op) override; spirv::Value VisitExpr_(const CastNode* op) override; spirv::Value VisitExpr_(const IntImmNode* op) override; - spirv::Value VisitExpr_(const UIntImmNode* op) override; spirv::Value VisitExpr_(const FloatImmNode* op) override; spirv::Value VisitExpr_(const StringImmNode* op) override; spirv::Value VisitExpr_(const AddNode* op) override; diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index d41d96db5165..d96883ed02fd 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -39,7 +39,7 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), id)); for (PrimExpr arg : call->args) { cargs.push_back(arg); diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index eccff6c74c2e..01096ae1dd46 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -280,12 +280,6 @@ void CodeGenStackVM::VisitExpr_(const IntImmNode* op) { this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } -void CodeGenStackVM::VisitExpr_(const UIntImmNode* op) { - CHECK(op->value <= std::numeric_limits::max()) - << "Int constant exceed bound"; - this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); -} - void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) { LOG(FATAL) << "Float Imm is not supported"; } diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 07989b2062e1..1360cc2d70f1 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -136,7 +136,6 @@ class CodeGenStackVM void VisitExpr_(const RampNode* op) final; void VisitExpr_(const BroadcastNode* op) final; void VisitExpr_(const IntImmNode* op) final; - void VisitExpr_(const UIntImmNode* op) final; void VisitExpr_(const FloatImmNode* op) final; void VisitExpr_(const StringImmNode* op) final; // statment diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 7e3d44f26aef..346ec3808919 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -79,10 +79,7 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) os << op->value; } -void CodeGenHybrid::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*) - PrintType(op->dtype, os); - os << "(" << op->value << ")"; -} + void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 89a1ece577f9..33bd0efae8a4 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -117,7 +117,6 @@ class CodeGenHybrid : void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 34ee4b3159a5..4fffc475a773 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -77,7 +77,6 @@ class AttrFunctor { virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::UIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. @@ -113,7 +112,6 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(StrMapNode); ATTR_FUNCTOR_DISPATCH(ArrayNode); ATTR_FUNCTOR_DISPATCH(IntImmNode); - ATTR_FUNCTOR_DISPATCH(UIntImmNode); ATTR_FUNCTOR_DISPATCH(FloatImmNode); ATTR_FUNCTOR_DISPATCH(StringImmNode); ATTR_FUNCTOR_DISPATCH(VarNode); @@ -157,7 +155,6 @@ class AttrsEqualHandler : bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::IntImmNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::UIntImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::FloatImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::StringImmNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final; @@ -198,7 +195,6 @@ class AttrsHashHandler : protected: size_t VisitAttrDefault_(const Object* lhs) final; size_t VisitAttr_(const ir::IntImmNode* lhs) final; - size_t VisitAttr_(const ir::UIntImmNode* lhs) final; size_t VisitAttr_(const ir::FloatImmNode* lhs) final; size_t VisitAttr_(const ir::StringImmNode* lhs) final; size_t VisitAttr_(const ArrayNode* lhs) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 1d3e767a5b71..a590f10e78e5 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -97,13 +97,6 @@ bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other return false; } -bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - return lhs->value == rhs->value; - } - return false; -} - bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; @@ -224,10 +217,6 @@ size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) { return std::hash()(op->value); } -size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) { - return std::hash()(op->value); -} - size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) { return std::hash()(op->value); } diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 6c7c54726eb9..5f9816f6898b 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -86,7 +86,6 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) } } - // maximum and min limits PrimExpr max_value(const DataType& dtype) { using namespace ir; @@ -101,11 +100,11 @@ PrimExpr max_value(const DataType& dtype) { } } else if (dtype.is_uint()) { if (dtype.bits() == 64) { - return UIntImmNode::make(dtype, std::numeric_limits::max()); + return make_const(dtype, std::numeric_limits::max()); } else if (dtype.bits() < 64) { uint64_t val = 1; val = (val << static_cast(dtype.bits())) - 1; - return UIntImmNode::make(dtype, val); + return IntImm(dtype, static_cast(val)); } } else if (dtype.is_float()) { if (dtype.bits() == 64) { @@ -132,7 +131,7 @@ PrimExpr min_value(const DataType& dtype) { return IntImm(dtype, val); } } else if (dtype.is_uint()) { - return UIntImmNode::make(dtype, 0); + return IntImm(dtype, 0); } else if (dtype.is_float()) { if (dtype.bits() == 64) { return FloatImmNode::make(dtype, std::numeric_limits::lowest()); @@ -163,24 +162,18 @@ inline bool ConstPowerHelper(ValueType val, int *shift) { bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) { if (const auto* op = x.as()) { return ConstPowerHelper(op->value, shift); - } else if (const auto* op = x.as()) { - return ConstPowerHelper(op->value, shift); } else { return false; } } PrimExpr cast(const DataType& t, PrimExpr value) { - using ir::IntImmNode; - using ir::UIntImmNode; using ir::FloatImmNode; if (value.dtype() == t) return value; // const fold IntImm as they are used in index computations if (t.lanes() == 1) { if (const IntImmNode* op = value.as()) { return make_const(t, op->value); - } else if (const UIntImmNode* op = value.as()) { - return make_const(t, op->value); } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value); } @@ -192,8 +185,6 @@ PrimExpr cast(const DataType& t, PrimExpr value) { if (value.dtype() != vtype) { if (const IntImmNode* op = value.as()) { value = make_const(vtype, op->value); - } else if (const UIntImmNode* op = value.as()) { - return make_const(t, op->value); } else if (const FloatImmNode* op = value.as()) { value = make_const(vtype, op->value); } else { @@ -330,18 +321,10 @@ PrimExpr max(PrimExpr a, PrimExpr b) { } PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { - using ir::IntImmNode; - using ir::UIntImmNode; CHECK(cond.dtype() == DataType::Bool(1)) << "if_then_else only accept the condition to be boolean type."; BinaryOpMatchTypes(true_value, false_value); - if (const UIntImmNode* op = cond.as()) { - if (op->value != 0) { - return true_value; - } else { - return false_value; - } - } else if (const IntImmNode* op = cond.as()) { + if (const IntImmNode* op = cond.as()) { if (op->value != 0) { return true_value; } else { diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 5a24e965e780..f06a6be5e75a 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -31,14 +31,6 @@ namespace tvm { namespace ir { // constructors -PrimExpr UIntImmNode::make(DataType t, uint64_t value) { - CHECK(t.is_uint() && t.lanes() == 1) - << "ValueError: UIntImm can only take scalar"; - ObjectPtr node = make_object(); - node->dtype = t; - node->value = value; - return PrimExpr(node); -} PrimExpr FloatImmNode::make(DataType t, double value) { CHECK_EQ(t.lanes(), 1) @@ -531,11 +523,6 @@ Stmt EvaluateNode::make(PrimExpr value) { } // Printers -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "(" << op->dtype << ")" << op->value; - }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { @@ -1153,7 +1140,6 @@ TVM_REGISTER_NODE_TYPE(AnyNode); TVM_REGISTER_NODE_TYPE(AttrStmtNode); TVM_REGISTER_NODE_TYPE(FloatImmNode); TVM_REGISTER_NODE_TYPE(IntImmNode); -TVM_REGISTER_NODE_TYPE(UIntImmNode); TVM_REGISTER_NODE_TYPE(StringImmNode); TVM_REGISTER_NODE_TYPE(CastNode); TVM_REGISTER_NODE_TYPE(VarNode); diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 612a56664c8c..0f350d2d732e 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -179,11 +179,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == - UIntImmNode::make(DataType::UInt(8), dtype.code()) && + IntImm(DataType::UInt(8), dtype.code()) && TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == - UIntImmNode::make(DataType::UInt(8), dtype.bits()) && + IntImm(DataType::UInt(8), dtype.bits()) && TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == - UIntImmNode::make(DataType::UInt(16), dtype.lanes())); + IntImm(DataType::UInt(16), dtype.lanes())); asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop)); // data field if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index 6eacb145b29b..8c441510c51d 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -252,10 +252,6 @@ class IRDeepCompare : CompareValue(op->value, other.as()->value); } - void VisitExpr_(const UIntImmNode *op, const PrimExpr& other) final { - CompareValue(op->value, other.as()->value); - } - void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final { CompareValue(op->value, other.as()->value); } diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index 67acec674630..857206f8dd9f 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -260,7 +260,6 @@ DEFINE_BINOP_VISIT_(AndNode); DEFINE_BINOP_VISIT_(OrNode); void ExprVisitor::VisitExpr_(const IntImmNode* op) {} -void ExprVisitor::VisitExpr_(const UIntImmNode* op) {} void ExprVisitor::VisitExpr_(const FloatImmNode* op) {} void ExprVisitor::VisitExpr_(const StringImmNode* op) {} @@ -640,7 +639,6 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index 7b760fa4a672..5aba355b7003 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -180,9 +180,6 @@ class AttrScopeLifter : public StmtMutator { if (const IntImmNode* op = a.as()) { return op->value == b.as()->value; } - if (const UIntImmNode* op = a.as()) { - return op->value == b.as()->value; - } return false; } diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index a0b07c293b05..d509169df0b1 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -120,7 +120,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const CommReducerNode *combiner = reduce_combiner_.back(); size_t size = combiner->result.size(); - const UIntImmNode *size_of_args = call->args[0].as(); + const IntImmNode *size_of_args = call->args[0].as(); CHECK(size_of_args) << call->args[0]->GetTypeKey(); CHECK_EQ(size, size_of_args->value); Array inits = combiner->identity_element; diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index 224a81c12396..9fb19cc4b308 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -96,7 +96,6 @@ class UnsafeExprDetector : public ExprFunctor { return false; } bool VisitExpr_(const VarNode* op) final { return false; } - bool VisitExpr_(const UIntImmNode* op) final { return false; } bool VisitExpr_(const IntImmNode* op) final { return false; } bool VisitExpr_(const FloatImmNode* op) final { return false; } bool VisitExpr_(const StringImmNode* op) final { return false; } diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index b2c50f7a8bd2..26ad59189671 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -159,14 +159,10 @@ class LoopUnroller : public StmtExprMutator { // constant folding. PrimExpr extent = ir::Simplify(op->extent); const IntImmNode *v1 = extent.as(); - const UIntImmNode *v2 = extent.as(); int value = -1; if (v1 != nullptr) { value = static_cast(v1->value); } - if (v2 != nullptr) { - value = static_cast(v2->value); - } return value; } diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 25650c7766cb..400a6bea22ed 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -857,10 +857,6 @@ class PrettyPrinter : return PrintConstScalar(op->dtype, &(op->value)); } - Doc VisitAttr_(const ir::UIntImmNode* op) final { - return PrintConstScalar(op->dtype, &(op->value)); - } - Doc VisitAttr_(const ir::FloatImmNode* op) final { return PrintConstScalar(op->dtype, &(op->value)); } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index d0d8b43f4c61..01280d209c0c 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -41,7 +41,7 @@ class TypeSolver::Reporter : public TypeReporterNode { } bool Assert(const IndexExpr& cond) final { - if (const uint64_t* pdiff = as_const_uint(cond)) { + if (const int64_t* pdiff = as_const_int(cond)) { return pdiff[0]; } return true; diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 378a5e3728f4..2e332413c1f6 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -47,14 +47,10 @@ static inline Array get_shape(const Type& type) { static inline const int32_t GetQmin(const DataType& dtype) { CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; - if (dtype.is_int()) { + if (dtype.is_int() || dtype.is_uint()) { auto* min_value = as_const_int(tvm::min_value(dtype)); CHECK(min_value != nullptr); return static_cast(min_value[0]); - } else if (dtype.is_uint()) { - auto* min_value = as_const_uint(tvm::min_value(dtype)); - CHECK(min_value != nullptr); - return static_cast(min_value[0]); } else { LOG(FATAL) << "Type not supported " << dtype; return -1; // To hide the warning @@ -64,14 +60,10 @@ static inline const int32_t GetQmin(const DataType& dtype) { static inline const int32_t GetQmax(const DataType& dtype) { CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; - if (dtype.is_int()) { + if (dtype.is_int() || dtype.is_uint()) { auto* max_value = as_const_int(tvm::max_value(dtype)); CHECK(max_value != nullptr); return static_cast(max_value[0]); - } else if (dtype.is_uint()) { - auto* max_value = as_const_uint(tvm::max_value(dtype)); - CHECK(max_value != nullptr); - return static_cast(max_value[0]); } else { LOG(FATAL) << "Type not supported " << dtype; return -1; // To hide the warning diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 5392eaeac1e8..193f2f206c06 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -127,10 +127,10 @@ TEST(Pattern, Basic) { } } -TEST(Pattern, Integer) { +TEST(Pattern, IntImm) { using namespace tvm; tvm::Var tx, ty; - arith::PVar c; + arith::PVar c; arith::PVar v; { // We can match integer and Var, both of which are diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index c3c40cf740ad..5f1facb2b45f 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -24,7 +24,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): val = tvm.ir_pass.Substitute(val, var_dict) val = tvm.ir_pass.Simplify(val) - assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm)) + assert isinstance(val, (tvm.expr.IntImm,)) return val.value ctx = tvm.context(target, 0) diff --git a/tests/python/unittest/test_lang_constructor.py b/tests/python/unittest/test_lang_constructor.py index fe329494e24e..c4187858a8a8 100644 --- a/tests/python/unittest/test_lang_constructor.py +++ b/tests/python/unittest/test_lang_constructor.py @@ -38,16 +38,11 @@ def test_expr_constructor(): assert x.value == 2 assert x.dtype == "int64" - x = tvm.expr.UIntImm("uint16", 2) - assert isinstance(x, tvm.expr.UIntImm) - assert x.value == 2 - assert x.dtype == "uint16" - x = tvm.expr.StringImm("xyza") assert isinstance(x, tvm.expr.StringImm) assert x.value == "xyza" - x = tvm.expr.Cast("float32", tvm.expr.IntImm("int32", 1)) + x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1)) assert isinstance(x, tvm.expr.Cast) assert x.dtype == "float32" assert x.value.value == 1 diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index c57f4a1109ec..ac2ee6d88cc5 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -29,7 +29,7 @@ def test_const_fold(): def check(f, *args): x = f(*[tvm.const(x, "int32") for x in args]) y = f(*args) - if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y): + if not isinstance(x, (tvm.expr.IntImm,)) or x.value != int(y): raise ValueError("check error: %s vs %s " % (x, y)) tmod = tvm.truncmod diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index 43ac3a29cd7c..e6de76f20881 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -43,8 +43,7 @@ using namespace tvm; */ inline bool IsConstInt(PrimExpr expr) { return - expr->IsInstance() || - expr->IsInstance(); + expr->IsInstance(); } /*! @@ -56,11 +55,8 @@ inline bool IsConstInt(PrimExpr expr) { * \return The integer value. */ inline int64_t GetConstInt(PrimExpr expr) { - if (expr->IsInstance()) { - return expr.as()->value; - } - if (expr->IsInstance()) { - return expr.as()->value; + if (expr->IsInstance()) { + return expr.as()->value; } LOG(ERROR) << "expr must be a constant integer"; return -1; diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 8f32a297d719..02d082b8b342 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -92,9 +92,9 @@ def get_const_int(expr): """ if isinstance(expr, Integral): return expr - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): expr = tvm.ir_pass.Simplify(expr) - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): raise ValueError("Expect value to be constant int") return int(expr.value) @@ -136,9 +136,9 @@ def equal_const_int(expr, value): """ if isinstance(expr, Integral): return expr == value - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): expr = tvm.ir_pass.Simplify(expr) - if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(expr, tvm.expr.IntImm): return False return expr.value == value @@ -160,9 +160,9 @@ def get_const_tuple(in_tuple): for elem in in_tuple: if isinstance(elem, tvm.expr.Var): ret.append(elem) - elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)): + elif not isinstance(elem, (tvm.expr.IntImm, int)): elem = tvm.ir_pass.Simplify(elem) - if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(elem, tvm.expr.IntImm): ret.append(elem) else: ret.append(get_const_int(elem)) From 4087b5d4cbf78e22d027281c8a008eeac209713f Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 14 Jan 2020 14:51:25 -0800 Subject: [PATCH 3/3] rename big->large --- include/tvm/expr_operator.h | 6 +++--- include/tvm/ir.h | 4 ++-- python/tvm/api.py | 2 +- src/api/api_lang.cc | 4 ++-- src/codegen/codegen_c.cc | 2 +- src/codegen/llvm/codegen_llvm.cc | 2 +- src/codegen/spirv/codegen_spirv.cc | 2 +- src/lang/expr_operator.cc | 4 ++-- tests/python/unittest/test_codegen_device.py | 4 ++-- tests/python/unittest/test_codegen_llvm.py | 4 ++-- 10 files changed, 17 insertions(+), 17 deletions(-) diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index d47759222112..ff3b340bf1fa 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -584,13 +584,13 @@ TVM_DLL PrimExpr nearbyint(PrimExpr x); TVM_DLL PrimExpr trunc(PrimExpr x); /*! - * \brief Construct a big uint constant by its low 32 bits and high 32bits. + * \brief Construct a large uint constant by its low 32 bits and high 32bits. * \param dtype The final data type. * \param low The lower 32 bits. * \param high The higher 32 bits. * \return The constructed expression. */ -TVM_DLL PrimExpr BigUIntImm(DataType dtype, int64_t low, int64_t high); +TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ @@ -674,7 +674,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { uint64_t mask = (static_cast(1) << 32U) - 1U; uint64_t low = uval & mask; uint64_t high = uval >> 32U; - return BigUIntImm(t, static_cast(low), static_cast(high)); + return LargeUIntImm(t, static_cast(low), static_cast(high)); } } if (t.is_float()) return ir::FloatImmNode::make(t, static_cast(value)); diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 20ebd92fc423..9c14a31be2fe 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1410,11 +1410,11 @@ namespace intrinsic { * * Construct a big uint that may not be representable by int64 * - * Expr tvm_big_uint_imm(uint32_t v0, uin32_t v1) { + * Expr tvm_large_uint_imm(uint32_t v0, uin32_t v1) { * return (v1 << 32) | v0; * } */ -constexpr const char* tvm_big_uint_imm = "tvm_big_uint_imm"; +constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm"; /*! * \brief See pesudo code * diff --git a/python/tvm/api.py b/python/tvm/api.py index 9afa0cc0609e..4bfe794c14d3 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -93,7 +93,7 @@ def const(value, dtype=None): if dtype is None: dtype = _scalar_type_inference(value) if dtype == "uint64" and value >= (1 << 63): - return _api_internal._BigUIntImm( + return _api_internal._LargeUIntImm( dtype, value & ((1 << 32) - 1), value >> 32) return _api_internal._const(value, dtype) diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 6b0cfdd55bd6..fa7b59d36b88 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -53,8 +53,8 @@ TVM_REGISTER_GLOBAL("_const") } }); -TVM_REGISTER_GLOBAL("_BigUIntImm") -.set_body_typed(BigUIntImm); +TVM_REGISTER_GLOBAL("_LargeUIntImm") +.set_body_typed(LargeUIntImm); TVM_REGISTER_GLOBAL("_str") .set_body_typed(ir::StringImmNode::make); diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 906631368f74..d9b7f7f08d12 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -527,7 +527,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) os << ")"; } else if (op->is_intrinsic(CallNode::bitwise_and)) { PrintBinaryIntrinsic(op, " & ", os, this); - } else if (op->is_intrinsic(intrinsic::tvm_big_uint_imm)) { + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { CHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 75982cc21848..60d8146fc0e6 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -720,7 +720,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return llvm::Constant::getNullValue(t_void_p_); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { return builder_->CreateIsNull(MakeValue(op->args[0])); - } else if (op->is_intrinsic(intrinsic::tvm_big_uint_imm)) { + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { CHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 8016444dad50..985f6816a640 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -282,7 +282,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else if (op->is_intrinsic(CallNode::reinterpret)) { return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), MakeValue(op->args[0])); - } else if (op->is_intrinsic(intrinsic::tvm_big_uint_imm)) { + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { CHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 5f9816f6898b..bd43d89d89d0 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -35,9 +35,9 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { return ir::CastNode::make(t, value); } -PrimExpr BigUIntImm(DataType t, int64_t low, int64_t high) { +PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { return ir::CallNode::make( - t, ir::intrinsic::tvm_big_uint_imm, + t, ir::intrinsic::tvm_large_uint_imm, {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, ir::CallNode::PureIntrinsic); diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 592b073767a8..5a10618fb269 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -18,7 +18,7 @@ from tvm.contrib import util import numpy as np -def test_big_uint_imm(): +def test_large_uint_imm(): value = (1 << 63) + 123 other = tvm.const(3, "uint64") n = 12 @@ -138,5 +138,5 @@ def check_module_save(device, host="stackvm"): if __name__ == "__main__": - test_big_uint_imm() + test_large_uint_imm() test_add_pipeline() diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index f21bc33dadd5..4920206ee019 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -88,7 +88,7 @@ def test_llvm_lookup_intrin(): fcode = tvm.build(func, None, "llvm") -def test_llvm_big_uintimm(): +def test_llvm_large_uintimm(): value = (1 << 63) + 123 other = tvm.const(3, "uint64") A = tvm.compute((), lambda : tvm.const(value, "uint64") + other, name='A') @@ -664,7 +664,7 @@ def vectorizer(op): tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) if __name__ == "__main__": - test_llvm_big_uintimm() + test_llvm_large_uintimm() test_llvm_import() test_alignment() test_rank_zero()