diff --git a/CMakeLists.txt b/CMakeLists.txt
index 80581076b925..b4ea8ba495f4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -46,6 +46,7 @@ tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
 tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)
 tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF)
 tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF)
+tvm_option(INDEX_DEFAULT_I64 "Defaults the index datatype to int64" ON)

 # 3rdparty libraries
 tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include")
@@ -259,6 +260,10 @@ if(NOT USE_RTTI)
   add_definitions(-DDMLC_ENABLE_RTTI=0)
 endif()

+if (INDEX_DEFAULT_I64)
+  add_definitions(-DTVM_INDEX_DEFAULT_I64=1)
+endif()
+
 list(APPEND RUNTIME_SRCS 3rdparty/bfloat16/bfloat16.cc)

 if(USE_RPC)
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index 7f9cea8c34d9..e7685f3a5447 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -73,6 +73,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_TARGET_ONNX="${USE_TARGET_ONNX}"
     TVM_INFO_USE_ARM_COMPUTE_LIB="${USE_ARM_COMPUTE_LIB}"
     TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME="${USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME}"
+    TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}"
   )
 endfunction()
diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h
index 03317c2c1dbb..201a0da94278 100644
--- a/include/tvm/topi/detail/constant_utils.h
+++ b/include/tvm/topi/detail/constant_utils.h
@@ -115,8 +115,10 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) {
   tvm::tir::ExprDeepEqual expr_equal;
   bool result = expr_equal(lhs, rhs);
   if (!result) {
-    PrimExpr zero(0);
-    result = expr_equal(tvm::arith::Analyzer().Simplify(lhs - rhs), zero);
+    PrimExpr t = tvm::arith::Analyzer().Simplify(lhs - rhs);
+    if (const IntImmNode* i = t.as<IntImmNode>()) {
+      result = i->value == 0;
+    }
   }
   return result;
 }
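The `EqualCheck` fix matters because `ExprDeepEqual` is dtype-sensitive: the old hard-coded `PrimExpr zero(0)` is an int32 immediate, so a simplified int64 difference of zero would never compare equal. A minimal Python sketch of the same pitfall, assuming the `expr_deep_equal` and `Analyzer.simplify` bindings shown:

```python
import tvm

deep_equal = tvm.tir.analysis.expr_deep_equal
# Same value, different dtype: deep equality fails, which is why the
# patch inspects the simplified IntImm's value instead of comparing
# against a fixed int32 zero.
assert not deep_equal(tvm.tir.IntImm("int32", 0), tvm.tir.IntImm("int64", 0))

# With int64 operands, Simplify(lhs - rhs) yields an int64 zero whose
# value-based check still succeeds.
x = tvm.tir.Var("x", "int64")
diff = tvm.arith.Analyzer().simplify(x - x)
assert isinstance(diff, tvm.tir.IntImm) and diff.value == 0
```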
diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py
index f60335a4d44b..d6508a6f61b7 100644
--- a/python/tvm/relay/backend/compile_engine.py
+++ b/python/tvm/relay/backend/compile_engine.py
@@ -23,6 +23,7 @@
 import tvm
 from tvm import te
 from tvm.runtime import Object
+from tvm.support import libinfo
 from ... import target as _target
 from ... import autotvm
 from .. import function as _function
@@ -80,9 +81,12 @@ def get_shape(shape):
     ret = []
     for dim in shape:
         if isinstance(dim, tvm.tir.IntImm):
-            val = int(dim)
-            assert val <= np.iinfo(np.int32).max
-            ret.append(tvm.tir.IntImm("int32", val))
+            if libinfo()["INDEX_DEFAULT_I64"] == "ON":
+                ret.append(dim)
+            else:
+                val = int(dim)
+                assert val <= np.iinfo(np.int32).max
+                ret.append(tvm.tir.IntImm("int32", val))
         elif isinstance(dim, tvm.tir.Any):
             ret.append(te.var("any_dim", "int32"))
         else:
diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc
index f68517032116..9275ec1bc394 100644
--- a/src/arith/bound_deducer.cc
+++ b/src/arith/bound_deducer.cc
@@ -92,21 +92,13 @@ class BoundDeducer : public ExprVisitor {
     }
   }

-  void VisitExpr_(const LTNode* op) final {
-    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
-  }
+  void VisitExpr_(const LTNode* op) final { success_ = false; }

-  void VisitExpr_(const LENode* op) final {
-    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
-  }
+  void VisitExpr_(const LENode* op) final { success_ = false; }

-  void VisitExpr_(const GTNode* op) final {
-    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
-  }
+  void VisitExpr_(const GTNode* op) final { success_ = false; }

-  void VisitExpr_(const GENode* op) final {
-    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
-  }
+  void VisitExpr_(const GENode* op) final { success_ = false; }

   void VisitExpr_(const AddNode* op) final {
     bool left = op->a.get() == path_[iter_];
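The `get_shape` branch above is keyed off the build flag surfaced through `libinfo()`. A sketch of how the two modes behave; the dimension values are hypothetical, and `INDEX_DEFAULT_I64` reads "ON" or "OFF" depending on the CMake option:

```python
import tvm
from tvm.support import libinfo
from tvm.relay.backend.compile_engine import get_shape

if libinfo().get("INDEX_DEFAULT_I64") == "ON":
    # int64 immediates now pass through untouched, however large.
    (dim,) = get_shape([tvm.tir.IntImm("int64", 2**40)])
    assert dim.dtype == "int64"
else:
    # Legacy behaviour: narrow to int32, asserting the value fits.
    (dim,) = get_shape([tvm.tir.IntImm("int64", 2**20)])
    assert dim.dtype == "int32"
```

The `bound_deducer.cc` hunks follow the same spirit: a nested comparison now marks the deduction as failed, so callers receive an unbounded set instead of the process aborting via `LOG(FATAL)`.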
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index 03645b42e0dc..9940d1f60b39 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -289,8 +289,8 @@ inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer, IntervalSet a, Int
   if (analyzer->CanProveGreaterEqual(divisor, 0)) {
     if (divisor.as<IntImmNode>()) {
       // a mod b = a - (a / b) * b if a_max / b == a_min / b
-      auto qmax = floordiv(a->max_value, divisor);
-      auto qmin = floordiv(a->min_value, divisor);
+      auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) : pos_inf();
+      auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf();
       if (analyzer->CanProve(qmax == qmin)) {
         auto tmax = a->max_value - divisor * qmin;
         auto tmin = a->min_value - divisor * qmin;
@@ -441,6 +441,15 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
     return Union(analyzer_, false_set, true_set);
   }

+  IntervalSet VisitExpr_(const CastNode* op) final {
+    IntervalSet value_set = this->Eval(op->value);
+    PrimExpr min_value =
+        value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf();
+    PrimExpr max_value =
+        value_set->HasUpperBound() ? cast(op->dtype, value_set->max_value) : pos_inf();
+    return IntervalSet(min_value, max_value);
+  }
+
   IntervalSet VisitExprDefault_(const Object* op) final {
     DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey();
     return IntervalSet::Everything();
@@ -609,6 +618,7 @@ bool IntSet::MatchRange(const Range& b) const {
   const IntSet& a = *this;
   const IntervalSetNode* a_int = a.as<IntervalSetNode>();
   if (!a_int) return false;
+  if (!a_int->HasUpperBound() || !a_int->HasLowerBound()) return false;
   Analyzer ana;
   return ProveEqual(&ana, a_int->min_value, b->min) &&
          ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1);
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 9ee57278e2f9..a083c3b83b12 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -78,9 +78,13 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
   for (IndexExpr val : shape) {
     const int64_t* pval = tir::as_const_int(val);
     if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
       CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
       CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
       res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+      res.push_back(val);
+#endif  // TVM_INDEX_DEFAULT_I64
     } else if (val->IsInstance<tir::AnyNode>()) {
       res.push_back(val.as<tir::AnyNode>()->ToVar());
     } else {
@@ -131,7 +135,6 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
       candidate_name = truncated_name.str();
     }
     cache_node->func_name = candidate_name;
-    CHECK(master_op_.defined());

     // Fusion over tupled results may leave identity relationships
     // between inputs and outputs, and those should not be scheduled.
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index e16ecb6d7119..16f5f0116b60 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -506,6 +506,9 @@ Array<te::Tensor> MeanCompute(const Attrs& attrs, const Array<te::Tensor>& input
   for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), param->axis, param->exclude)) {
     count *= inputs[0]->shape[i];
   }
+  // Although count is created as inputs[0]->dtype,
+  // its type may be changed (promoted) during multiplication
+  count = cast(inputs[0]->dtype, count);
   auto res = ReduceCompute(attrs, inputs, out_type, topi::sum);
   return {topi::divide(res[0], count)};
 }
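The new `CastNode` visitor in `int_set.cc` can be exercised from Python. A sketch assuming the `Analyzer.int_set` and `IntervalSet` bindings; before this change a cast fell through to `VisitExprDefault_` and collapsed to `Everything()`:

```python
import tvm
from tvm import tir

ana = tvm.arith.Analyzer()
i = tir.Var("i", "int64")
dom = {i: tvm.arith.IntervalSet(tir.IntImm("int64", 0), tir.IntImm("int64", 15))}
# The operand's bounds are cast through to int32 rather than being lost.
bounds = ana.int_set(tir.Cast("int32", i), dom)
print(bounds.min_value, bounds.max_value)  # expected: 0 and 15, as int32
```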
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 94bc9d1596b5..0bdae82a31e0 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -200,6 +200,10 @@
 #define TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "NOT-FOUND"
 #endif

+#ifndef TVM_INFO_INDEX_DEFAULT_I64
+#define TVM_INFO_INDEX_DEFAULT_I64 "NOT-FOUND"
+#endif
+
 namespace tvm {

 /*!
@@ -253,7 +257,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
       {"USE_TARGET_ONNX", TVM_INFO_USE_TARGET_ONNX},
       {"USE_ARM_COMPUTE_LIB", TVM_INFO_USE_ARM_COMPUTE_LIB},
       {"USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME", TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME},
-  };
+      {"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64}};
   return result;
 }
diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc
index 61b782629d19..db649e541a65 100644
--- a/src/te/operation/op_util.cc
+++ b/src/te/operation/op_util.cc
@@ -112,8 +112,8 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
       }
     }
     if (!debug_keep_trivial_loop && is_one(dom->extent)) {
-      nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op));
-      value_map[iv] = dom->min;
+      nest[i + 1].emplace_back(LetStmt(var, cast(var.dtype(), dom->min), no_op));
+      value_map[iv] = cast(var.dtype(), dom->min);
     } else if (is_zero(dom->min)) {
       nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op));
       value_map[iv] = var;
diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index 4313be85ee85..0a82673aa4b8 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -207,6 +207,10 @@ void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
     if (!is_zero(inner_min)) {
       state[s->inner] = state[s->inner] + inner_min;
     }
+    // s->fused, s->outer and s->inner may be of different dtype,
+    // so we cast the `state` back to its original dtype
+    state[s->outer] = cast(s->outer->var.dtype(), state[s->outer]);
+    state[s->inner] = cast(s->inner->var.dtype(), state[s->inner]);
   } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
     if (!state.count(s->rebased)) {
       CHECK(allow_missing);
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index 6473c7e3cf67..d6327ffe0f08 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -55,6 +55,16 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
   return 0;
 }

+DataType MatchDataType(std::vector<DataType> dtypes) {
+  int max_bits = -1;
+  for (const auto& dtype : dtypes) {
+    CHECK(dtype.is_int());
+    CHECK(dtype.is_scalar());
+    max_bits = std::max(max_bits, dtype.bits());
+  }
+  return DataType::Int(max_bits);
+}
+
 void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts,
                  IterVar* p_outer, IterVar* p_inner) {
   // Check if split is valid.
@@ -228,8 +238,9 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) {  // NOLINT
   IterVarType iter_type = outer->iter_type;
   if (inner->iter_type > iter_type) iter_type = inner->iter_type;
   std::string fused_name = outer->var->name_hint + "." + inner->var->name_hint + ".fused";
+  DataType iter_dtype = MatchDataType({inner->var.dtype(), outer->var.dtype()});

-  IterVar fused = IterVar(Range(), Var(fused_name, outer->var.dtype()), iter_type);
+  IterVar fused = IterVar(Range(), Var(fused_name, iter_dtype), iter_type);

   Array<IterVar>& all_vars = self->all_iter_vars;
   Array<IterVar>& leaf_vars = self->leaf_iter_vars;
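`MatchDataType` is observable through `te.fuse`. A sketch, assuming TE derives each axis var's dtype from the corresponding shape extent; the fused var should now take the wider of the two dtypes instead of blindly inheriting the outer axis' dtype:

```python
import tvm
from tvm import te

n = tvm.tir.IntImm("int64", 2**34)  # extent only representable in int64
m = tvm.tir.IntImm("int32", 16)
A = te.placeholder((n, m), name="A")
B = te.compute((n, m), lambda i, j: A[i, j] + 1.0, name="B")
s = te.create_schedule(B.op)
fused = s[B].fuse(*s[B].op.axis)
print(fused.var.dtype)  # expected: int64
```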
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py
index 6179bbbfbd07..657149967923 100644
--- a/tests/python/unittest/test_tir_transform_narrow_datatype.py
+++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -16,6 +16,7 @@
 # under the License.
 import tvm
 from tvm import te
+from tvm import relay
 from tvm.tir import const
@@ -38,6 +39,7 @@ def lower_sch(sch, args, target_bits):
             arg_list.append(buf)
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
+    sch = sch.normalize()
     bounds = te.schedule.InferBound(sch)
     stmt = te.schedule.ScheduleOps(sch, bounds)
@@ -189,9 +191,56 @@ def check(m, n, target_bits, target_dtype):
                   target_bits=32, target_dtype='int64')


+def test_relay_basic():
+    engine = relay.backend.compile_engine.get()
+
+    def check(shapex, shapey, target_bits, target_dtype):
+        x = relay.var('x', shape=shapex)
+        y = relay.var('y', shape=shapey)
+        z = relay.add(x, y)
+        func = relay.Function([x, y], z)
+        mod = tvm.IRModule.from_expr(func)
+        func = mod["main"]
+        z = engine.lower(func, "llvm")
+        stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32)
+        # outer loop
+        assert stmt.loop_var.dtype == target_dtype
+        # inner loop
+        if len(shapex) > 1 or len(shapey) > 1:
+            assert stmt.body.loop_var.dtype == target_dtype
+
+    check((const(2**16, 'int64'), const(2**15 + 1, 'int64')),
+          (1, const(2**15 + 1, 'int64')),
+          target_bits=32, target_dtype="int64")
+    check((const(2**16, 'int64'), const(2**15, 'int64')),
+          (1, const(2**15, 'int64')),
+          target_bits=32, target_dtype="int32")
+    check((const(2**31, 'int64'),), (const(2**31, 'int64'),),
+          target_bits=32, target_dtype="int32")
+    check((const(2**31 + 1, 'int64'),), (const(2**31 + 1, 'int64'),),
+          target_bits=32, target_dtype="int64")
+
+
+def test_relay_take():
+    engine = relay.backend.compile_engine.get()
+
+    def check(shape, index, target_bits, target_dtype):
+        x = relay.var("x", shape=shape)
+        y = relay.op.take(x, indices=index)
+        func = relay.Function([x], y)
+        mod = tvm.IRModule.from_expr(func)
+        func = mod["main"]
+        z = engine.lower(func, "llvm")
+        stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32)
+        assert stmt.value.index.dtype == target_dtype
+
+    check((const(2**16, 'int64'), const(2**15 + 1, 'int64')),
+          relay.const(0, dtype="int64"),
+          target_bits=32, target_dtype="int32")
+    check((const(2**16, 'int64'), const(2**15 + 1, 'int64')),
+          relay.const(2**31, dtype="int64"),
+          target_bits=32, target_dtype="int64")
+
+
 if __name__ == "__main__":
     test_basic()
     test_thread_axis()
     test_multilanes()
     test_reduce()
     test_slice()
+    test_relay_basic()
+    test_relay_take()
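Beyond the unit tests, the flow can be seen end to end: indices start life as int64, and `NarrowDataType(32)`, which runs inside `tvm.lower`, demotes loop variables whose extents provably fit in 32 bits. A minimal sketch:

```python
import tvm
from tvm import te

n = tvm.tir.IntImm("int64", 2**16)  # fits comfortably in int32
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
# The printed loop variable should be int32 after narrowing.
print(tvm.lower(s, [A, B], simple_mode=True))
```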