From 4f9483bde15ab389376ef16f81c3d03d6de8f494 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 9 Jul 2020 11:44:50 +0800 Subject: [PATCH 01/21] Fix fuse tests --- include/tvm/topi/detail/constant_utils.h | 6 ++++-- src/relay/backend/compile_engine.cc | 7 +------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h index 03317c2c1dbb..201a0da94278 100644 --- a/include/tvm/topi/detail/constant_utils.h +++ b/include/tvm/topi/detail/constant_utils.h @@ -115,8 +115,10 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { tvm::tir::ExprDeepEqual expr_equal; bool result = expr_equal(lhs, rhs); if (!result) { - PrimExpr zero(0); - result = expr_equal(tvm::arith::Analyzer().Simplify(lhs - rhs), zero); + PrimExpr t = tvm::arith::Analyzer().Simplify(lhs - rhs); + if (const IntImmNode* i = t.as()) { + result = i->value == 0; + } } return result; } diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 9ee57278e2f9..d13e849b802a 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -76,12 +76,7 @@ Array GetShape(const Array& shape) { // even if the result of shape inference becomes int64. Array res; for (IndexExpr val : shape) { - const int64_t* pval = tir::as_const_int(val); - if (pval != nullptr) { - CHECK_LE(pval[0], std::numeric_limits::max()); - CHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(IntImm(DataType::Int(32), *pval)); - } else if (val->IsInstance()) { + if (val->IsInstance()) { res.push_back(val.as()->ToVar()); } else { res.push_back(val); From 5a0c944f87300cff75f71ffd737997c607e90362 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 9 Jul 2020 16:06:56 +0800 Subject: [PATCH 02/21] Cast when extent=1 --- src/te/operation/op_util.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index 61b782629d19..a05d61232b5e 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -112,7 +112,7 @@ std::vector > MakeLoopNest(const Stage& stage, } } if (!debug_keep_trivial_loop && is_one(dom->extent)) { - nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op)); + nest[i + 1].emplace_back(LetStmt(var, cast(var.dtype(), dom->min), no_op)); value_map[iv] = dom->min; } else if (is_zero(dom->min)) { nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op)); From 87c8d9fbe5732952c90d8717571073f08c11a52f Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 9 Jul 2020 16:29:14 +0800 Subject: [PATCH 03/21] Print --- src/tir/ir/expr.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 687dfd630f1d..07ee873b67d1 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -102,7 +102,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) auto* op = static_cast(node.get()); // omit the type // stream << op->name << "." << op->type; - p->stream << op->name_hint; + // p->stream << op->name_hint; + p->stream << op->name_hint << "." << op->dtype; }); // SizeVar From 1b9c3002dd4d5c70dd12c262a587cb4be467400d Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 9 Jul 2020 16:30:40 +0800 Subject: [PATCH 04/21] Print --- src/tir/ir/expr.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 07ee873b67d1..2a8661f22f34 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -145,9 +145,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "iter_var("; - if (op->var->name_hint.length() != 0) { - p->stream << op->var->name_hint << ", "; - } + p->stream << op->var << ", "; + // if (op->var->name_hint.length() != 0) { + // p->stream << op->var->name_hint << ", "; + // } if (op->dom.defined()) { p->stream << op->dom; } From 328a917d60974116c27ecb1898cebbacc1c0ea41 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 9 Jul 2020 16:35:16 +0800 Subject: [PATCH 05/21] Cast when extent=1 --- src/te/operation/op_util.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index a05d61232b5e..db649e541a65 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -113,7 +113,7 @@ std::vector > MakeLoopNest(const Stage& stage, } if (!debug_keep_trivial_loop && is_one(dom->extent)) { nest[i + 1].emplace_back(LetStmt(var, cast(var.dtype(), dom->min), no_op)); - value_map[iv] = dom->min; + value_map[iv] = cast(var.dtype(), dom->min); } else if (is_zero(dom->min)) { nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op)); value_map[iv] = var; From a387094e08eebd65d7b1e8eb243c0172219a3584 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 9 Jul 2020 17:12:31 +0800 Subject: [PATCH 06/21] Print relay func --- src/relay/backend/compile_engine.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index d13e849b802a..052c8bc87f67 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -126,7 +126,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> candidate_name = truncated_name.str(); } cache_node->func_name = candidate_name; - + std::cout << "pf " << candidate_name << ": " << std::endl << PrettyPrint(prim_func) << std::endl; CHECK(master_op_.defined()); // Fusion over tupled results may leave identity relationships // between inputs and outputs, and those should not be scheduled. From 29e7feee5495ea07c986fd0c072a59628e042a1c Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 9 Jul 2020 17:22:22 +0800 Subject: [PATCH 07/21] Fix int_set when unbounded --- src/arith/int_set.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 03645b42e0dc..e26e96af32ea 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -289,8 +289,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int if (analyzer->CanProveGreaterEqual(divisor, 0)) { if (divisor.as()) { // a mod b = a - (a / b) * b if a_max / b == a_min / b - auto qmax = floordiv(a->max_value, divisor); - auto qmin = floordiv(a->min_value, divisor); + auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) : pos_inf(); + auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf(); if (analyzer->CanProve(qmax == qmin)) { auto tmax = a->max_value - divisor * qmin; auto tmin = a->min_value - divisor * qmin; From 039f91e29c0c80bd1c645a4bd5ae76580ff7ea6a Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Fri, 10 Jul 2020 15:37:10 +0800 Subject: [PATCH 08/21] Fix mean --- src/relay/op/tensor/reduce.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index e16ecb6d7119..16f5f0116b60 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -506,6 +506,9 @@ Array MeanCompute(const Attrs& attrs, const Array& input for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), param->axis, param->exclude)) { count *= inputs[0]->shape[i]; } + // Although count is created as inputs[0]->dtype, + // its type may be changed (promoted) during multiplication + count = cast(inputs[0]->dtype, count); auto res = ReduceCompute(attrs, inputs, out_type, topi::sum); return {topi::divide(res[0], count)}; } From 77e379cb7d44c3d9e8bbec497968d8e1d2e2cd64 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Sat, 11 Jul 2020 00:55:16 +0800 Subject: [PATCH 09/21] Fix int_set when unbounded --- src/arith/int_set.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index e26e96af32ea..8d2dd2d4287a 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -609,6 +609,7 @@ bool IntSet::MatchRange(const Range& b) const { const IntSet& a = *this; const IntervalSetNode* a_int = a.as(); if (!a_int) return false; + if (!a_int->HasUpperBound() || !a_int->HasLowerBound()) return false; Analyzer ana; return ProveEqual(&ana, a_int->min_value, b->min) && ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); From 05832dcda1c252fbe59a1a59cf97489b38e66c83 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 27 Jul 2020 20:43:32 +0800 Subject: [PATCH 10/21] impl CastNode for int_set --- src/arith/int_set.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 8d2dd2d4287a..7b0f62ae9e87 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -441,6 +441,15 @@ class IntervalSetEvaluator : public ExprFunctor { return Union(analyzer_, false_set, true_set); } + IntervalSet VisitExpr_(const CastNode* op) final { + IntervalSet value_set = this->Eval(op->value); + PrimExpr min_value = value_set->HasLowerBound() ? + cast(op->dtype, value_set->min_value) : neg_inf(); + PrimExpr max_value = value_set->HasUpperBound() ? + cast(op->dtype, value_set->max_value) : pos_inf(); + return IntervalSet(min_value, max_value); + } + IntervalSet VisitExprDefault_(const Object* op) final { DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); return IntervalSet::Everything(); From 8bf52307c11ee879e8b0dbf5791ab4273d31778e Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 27 Jul 2020 21:16:29 +0800 Subject: [PATCH 11/21] lint --- src/arith/int_set.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 7b0f62ae9e87..25f5ee7c8434 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -443,10 +443,10 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet VisitExpr_(const CastNode* op) final { IntervalSet value_set = this->Eval(op->value); - PrimExpr min_value = value_set->HasLowerBound() ? - cast(op->dtype, value_set->min_value) : neg_inf(); - PrimExpr max_value = value_set->HasUpperBound() ? - cast(op->dtype, value_set->max_value) : pos_inf(); + PrimExpr min_value = + value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf(); + PrimExpr max_value = + value_set->HasUpperBound() ? cast(op->dtype, value_set->max_value) : pos_inf(); return IntervalSet(min_value, max_value); } From 9a1d3f5c6e17770321f4e57108467a60c40c02b9 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 27 Jul 2020 20:52:15 +0800 Subject: [PATCH 12/21] remove debug outputs --- src/relay/backend/compile_engine.cc | 1 - src/tir/ir/expr.cc | 10 ++++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 052c8bc87f67..4177cd71e532 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -126,7 +126,6 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> candidate_name = truncated_name.str(); } cache_node->func_name = candidate_name; - std::cout << "pf " << candidate_name << ": " << std::endl << PrettyPrint(prim_func) << std::endl; CHECK(master_op_.defined()); // Fusion over tupled results may leave identity relationships // between inputs and outputs, and those should not be scheduled. diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 2a8661f22f34..687dfd630f1d 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -102,8 +102,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) auto* op = static_cast(node.get()); // omit the type // stream << op->name << "." << op->type; - // p->stream << op->name_hint; - p->stream << op->name_hint << "." << op->dtype; + p->stream << op->name_hint; }); // SizeVar @@ -145,10 +144,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "iter_var("; - p->stream << op->var << ", "; - // if (op->var->name_hint.length() != 0) { - // p->stream << op->var->name_hint << ", "; - // } + if (op->var->name_hint.length() != 0) { + p->stream << op->var->name_hint << ", "; + } if (op->dom.defined()) { p->stream << op->dom; } From 8898ef0b34820173ccac07e4ebf53ec68ac2649b Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 27 Jul 2020 21:20:51 +0800 Subject: [PATCH 13/21] lint --- src/arith/int_set.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 25f5ee7c8434..9940d1f60b39 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -444,9 +444,9 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet VisitExpr_(const CastNode* op) final { IntervalSet value_set = this->Eval(op->value); PrimExpr min_value = - value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf(); + value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf(); PrimExpr max_value = - value_set->HasUpperBound() ? cast(op->dtype, value_set->max_value) : pos_inf(); + value_set->HasUpperBound() ? cast(op->dtype, value_set->max_value) : pos_inf(); return IntervalSet(min_value, max_value); } From 7b409bd9e3c2bfcebebc467da88805df403f5cbd Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 28 Jul 2020 11:23:17 +0800 Subject: [PATCH 14/21] fix bound deducer --- src/arith/bound_deducer.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index f68517032116..dc62826f3d85 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -93,19 +93,19 @@ class BoundDeducer : public ExprVisitor { } void VisitExpr_(const LTNode* op) final { - LOG(FATAL) << "unable to deduce due to multiple comparison operator"; + success_ = false; } void VisitExpr_(const LENode* op) final { - LOG(FATAL) << "unable to deduce due to multiple comparison operator"; + success_ = false; } void VisitExpr_(const GTNode* op) final { - LOG(FATAL) << "unable to deduce due to multiple comparison operator"; + success_ = false; } void VisitExpr_(const GENode* op) final { - LOG(FATAL) << "unable to deduce due to multiple comparison operator"; + success_ = false; } void VisitExpr_(const AddNode* op) final { From 197b80b4359312082958bcf14841b739c6ea622a Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 28 Jul 2020 11:32:50 +0800 Subject: [PATCH 15/21] lint --- src/arith/bound_deducer.cc | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index dc62826f3d85..9275ec1bc394 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -92,21 +92,13 @@ class BoundDeducer : public ExprVisitor { } } - void VisitExpr_(const LTNode* op) final { - success_ = false; - } + void VisitExpr_(const LTNode* op) final { success_ = false; } - void VisitExpr_(const LENode* op) final { - success_ = false; - } + void VisitExpr_(const LENode* op) final { success_ = false; } - void VisitExpr_(const GTNode* op) final { - success_ = false; - } + void VisitExpr_(const GTNode* op) final { success_ = false; } - void VisitExpr_(const GENode* op) final { - success_ = false; - } + void VisitExpr_(const GENode* op) final { success_ = false; } void VisitExpr_(const AddNode* op) final { bool left = op->a.get() == path_[iter_]; From 1ba25b79f8bf451122061964638ea8765c73040c Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Wed, 29 Jul 2020 12:14:26 +0800 Subject: [PATCH 16/21] fix fuse --- src/te/schedule/message_passing.cc | 4 ++++ src/te/schedule/schedule_lang.cc | 13 ++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 4313be85ee85..0a82673aa4b8 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -207,6 +207,10 @@ void PassUpIndex(const Stage& stage, const Map& dom_map, if (!is_zero(inner_min)) { state[s->inner] = state[s->inner] + inner_min; } + // s->fused, s->outer and s->inner may be of different dtype, + // so we cast the `state` back to its original dtype + state[s->outer] = cast(s->outer->var.dtype(), state[s->outer]); + state[s->inner] = cast(s->inner->var.dtype(), state[s->inner]); } else if (const RebaseNode* s = rel.as()) { if (!state.count(s->rebased)) { CHECK(allow_missing); diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 6473c7e3cf67..d6327ffe0f08 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -55,6 +55,16 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) return 0; } +DataType MatchDataType(std::vector dtypes) { + int max_bits = -1; + for (const auto& dtype : dtypes) { + CHECK(dtype.is_int()); + CHECK(dtype.is_scalar()); + max_bits = std::max(max_bits, dtype.bits()); + } + return DataType::Int(max_bits); +} + void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // Check if split is valid. @@ -228,8 +238,9 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT IterVarType iter_type = outer->iter_type; if (inner->iter_type > iter_type) iter_type = inner->iter_type; std::string fused_name = outer->var->name_hint + "." + inner->var->name_hint + ".fused"; + DataType iter_dtype = MatchDataType({inner->var.dtype(), outer->var.dtype()}); - IterVar fused = IterVar(Range(), Var(fused_name, outer->var.dtype()), iter_type); + IterVar fused = IterVar(Range(), Var(fused_name, iter_dtype), iter_type); Array& all_vars = self->all_iter_vars; Array& leaf_vars = self->leaf_iter_vars; From 3f5c0d990970753e66fa0c6ec969acebce064dd0 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 30 Jul 2020 11:39:05 +0800 Subject: [PATCH 17/21] add tests --- python/tvm/relay/backend/compile_engine.py | 6 +-- .../test_tir_transform_narrow_datatype.py | 49 +++++++++++++++++++ 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index f60335a4d44b..98d96968884d 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -79,11 +79,7 @@ def get_shape(shape): """Convert the shape to correct dtype and vars.""" ret = [] for dim in shape: - if isinstance(dim, tvm.tir.IntImm): - val = int(dim) - assert val <= np.iinfo(np.int32).max - ret.append(tvm.tir.IntImm("int32", val)) - elif isinstance(dim, tvm.tir.Any): + if isinstance(dim, tvm.tir.Any): ret.append(te.var("any_dim", "int32")) else: ret.append(dim) diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 6179bbbfbd07..657149967923 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm import relay from tvm.tir import const @@ -38,6 +39,7 @@ def lower_sch(sch, args, target_bits): arg_list.append(buf) else: raise ValueError("args must be Tensor, Buffer or Var") + sch = sch.normalize() bounds = te.schedule.InferBound(sch) stmt = te.schedule.ScheduleOps(sch, bounds) @@ -189,9 +191,56 @@ def check(m, n, target_bits, target_dtype): target_bits=32, target_dtype='int64') +def test_relay_basic(): + engine = relay.backend.compile_engine.get() + def check(shapex, shapey, target_bits, target_dtype): + x = relay.var('x', shape=shapex) + y = relay.var('y', shape=shapey) + z = relay.add(x, y) + func = relay.Function([x, y], z) + mod = tvm.IRModule.from_expr(func) + func = mod["main"] + z = engine.lower(func, "llvm") + stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) + # outer loop + assert stmt.loop_var.dtype == target_dtype + # inner loop + if len(shapex) > 1 or len(shapey) > 1: + assert stmt.body.loop_var.dtype == target_dtype + + check((const(2**16, 'int64'), const(2**15 + 1, 'int64')), (1, const(2**15 + 1, 'int64')), + target_bits=32, target_dtype="int64") + check((const(2**16, 'int64'), const(2**15, 'int64')), (1, const(2**15, 'int64')), + target_bits=32, target_dtype="int32") + check((const(2**31, 'int64'),), (const(2**31, 'int64'),), + target_bits=32, target_dtype="int32") + check((const(2**31 + 1, 'int64'),), (const(2**31 + 1, 'int64'),), + target_bits=32, target_dtype="int64") + + +def test_relay_take(): + engine = relay.backend.compile_engine.get() + def check(shape, index, target_bits, target_dtype): + x = relay.var("x", shape=shape) + y = relay.op.take(x, indices=index) + func = relay.Function([x], y) + mod = tvm.IRModule.from_expr(func) + func = mod["main"] + z = engine.lower(func, "llvm") + stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) + assert stmt.value.index.dtype == target_dtype + + check((const(2**16, 'int64'), const(2**15 + 1, 'int64')), relay.const(0, dtype="int64"), + target_bits=32, target_dtype="int32") + check((const(2**16, 'int64'), const(2**15 + 1, 'int64')), relay.const(2**31, dtype="int64"), + target_bits=32, target_dtype="int64") + + if __name__ == "__main__": test_basic() test_thread_axis() test_multilanes() test_reduce() test_slice() + test_relay_basic() + test_relay_take() From 83f3702eb37b4c38566e9f88a44d87aee7f70707 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 30 Jul 2020 16:11:27 +0800 Subject: [PATCH 18/21] lint --- python/tvm/relay/backend/compile_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 98d96968884d..75fca169325e 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -19,7 +19,6 @@ from __future__ import absolute_import import logging -import numpy as np import tvm from tvm import te from tvm.runtime import Object From 963b6f250de54501219fc9f07abd96d7ada3b020 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 31 Aug 2020 10:40:51 +0800 Subject: [PATCH 19/21] add a flag --- CMakeLists.txt | 5 +++++ cmake/modules/LibInfo.cmake | 1 + python/tvm/relay/backend/compile_engine.py | 11 ++++++++++- src/relay/backend/compile_engine.cc | 11 ++++++++++- src/support/libinfo.cc | 5 +++++ 5 files changed, 31 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 80581076b925..b4ea8ba495f4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,6 +46,7 @@ tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF) tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF) tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF) tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF) +tvm_option(INDEX_DEFAULT_I64 "Defaults the index datatype to int64" ON) # 3rdparty libraries tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include") @@ -259,6 +260,10 @@ if(NOT USE_RTTI) add_definitions(-DDMLC_ENABLE_RTTI=0) endif() +if (INDEX_DEFAULT_I64) + add_definitions(-DTVM_INDEX_DEFAULT_I64=1) +endif() + list(APPEND RUNTIME_SRCS 3rdparty/bfloat16/bfloat16.cc) if(USE_RPC) diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index 7f9cea8c34d9..e7685f3a5447 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -73,6 +73,7 @@ function(add_lib_info src_file) TVM_INFO_USE_TARGET_ONNX="${USE_TARGET_ONNX}" TVM_INFO_USE_ARM_COMPUTE_LIB="${USE_ARM_COMPUTE_LIB}" TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME="${USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME}" + TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}" ) endfunction() diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 75fca169325e..b34979d13a34 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -20,8 +20,10 @@ import logging import tvm +import numpy as np from tvm import te from tvm.runtime import Object +from tvm.support import libinfo from ... import target as _target from ... import autotvm from .. import function as _function @@ -78,7 +80,14 @@ def get_shape(shape): """Convert the shape to correct dtype and vars.""" ret = [] for dim in shape: - if isinstance(dim, tvm.tir.Any): + if isinstance(dim, tvm.tir.IntImm): + if (libinfo()["INDEX_DEFAULT_I64"] == "ON"): + ret.append(dim) + else: + val = int(dim) + assert val <= np.iinfo(np.int32).max + ret.append(tvm.tir.IntImm("int32", val)) + elif isinstance(dim, tvm.tir.Any): ret.append(te.var("any_dim", "int32")) else: ret.append(dim) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 4177cd71e532..a083c3b83b12 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -76,7 +76,16 @@ Array GetShape(const Array& shape) { // even if the result of shape inference becomes int64. Array res; for (IndexExpr val : shape) { - if (val->IsInstance()) { + const int64_t* pval = tir::as_const_int(val); + if (pval != nullptr) { +#ifndef TVM_INDEX_DEFAULT_I64 + CHECK_LE(pval[0], std::numeric_limits::max()); + CHECK_GE(pval[0], std::numeric_limits::min()); + res.push_back(IntImm(DataType::Int(32), *pval)); +#else + res.push_back(val); +#endif // TVM_INDEX_DEFAULT_I64 + } else if (val->IsInstance()) { res.push_back(val.as()->ToVar()); } else { res.push_back(val); diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 94bc9d1596b5..1a7b91dba18e 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -200,6 +200,10 @@ #define TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "NOT-FOUND" #endif +#ifndef TVM_INFO_INDEX_DEFAULT_I64 +#define TVM_INFO_INDEX_DEFAULT_I64 "NOT-FOUND" +#endif + namespace tvm { /*! @@ -253,6 +257,7 @@ TVM_DLL Map GetLibInfo() { {"USE_TARGET_ONNX", TVM_INFO_USE_TARGET_ONNX}, {"USE_ARM_COMPUTE_LIB", TVM_INFO_USE_ARM_COMPUTE_LIB}, {"USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME", TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME}, + {"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64} }; return result; } From e87cc4778e1a627f5e00cf8a698847d945d99b3d Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 31 Aug 2020 14:02:17 +0800 Subject: [PATCH 20/21] fix --- python/tvm/relay/backend/compile_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index b34979d13a34..d6508a6f61b7 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -19,8 +19,8 @@ from __future__ import absolute_import import logging -import tvm import numpy as np +import tvm from tvm import te from tvm.runtime import Object from tvm.support import libinfo @@ -81,7 +81,7 @@ def get_shape(shape): ret = [] for dim in shape: if isinstance(dim, tvm.tir.IntImm): - if (libinfo()["INDEX_DEFAULT_I64"] == "ON"): + if libinfo()["INDEX_DEFAULT_I64"] == "ON": ret.append(dim) else: val = int(dim) From 6ae5d1ec1c7264edb0e5a92a36464e76685fe510 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 31 Aug 2020 14:13:24 +0800 Subject: [PATCH 21/21] fix --- src/support/libinfo.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 1a7b91dba18e..0bdae82a31e0 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -257,8 +257,7 @@ TVM_DLL Map GetLibInfo() { {"USE_TARGET_ONNX", TVM_INFO_USE_TARGET_ONNX}, {"USE_ARM_COMPUTE_LIB", TVM_INFO_USE_ARM_COMPUTE_LIB}, {"USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME", TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME}, - {"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64} - }; + {"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64}}; return result; }