diff --git a/docs/api/python/dev.rst b/docs/api/python/dev.rst index 7bb938ca7517..2c6bec4b2e1b 100644 --- a/docs/api/python/dev.rst +++ b/docs/api/python/dev.rst @@ -73,6 +73,7 @@ tvm.ir_pass tvm.ir_pass.SplitPipeline tvm.ir_pass.LowerThreadAllreduce tvm.ir_pass.LowerIntrin + tvm.ir_pass.RemoveIntrin tvm.ir_pass.LowerTVMBuiltin tvm.ir_pass.NarrowChannelAccess diff --git a/docs/api/python/tvm.rst b/docs/api/python/tvm.rst index b517195db9e4..2636a7b2b945 100644 --- a/docs/api/python/tvm.rst +++ b/docs/api/python/tvm.rst @@ -45,6 +45,7 @@ The user facing API for computation declaration. tvm.min tvm.max tvm.tag_scope + tvm.assert_bound .. autofunction:: tvm.load_json .. autofunction:: tvm.save_json @@ -70,3 +71,4 @@ The user facing API for computation declaration. .. autofunction:: tvm.min .. autofunction:: tvm.max .. autofunction:: tvm.tag_scope +.. autofunction:: tvm.assert_bound diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 41e7aa5b7796..e711364997e0 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -597,6 +597,16 @@ TVM_DLL Expr nearbyint(Expr x); */ TVM_DLL Expr trunc(Expr x); +/*! + * \brief Pass bound information of value. + * \param value The input expression. + * \param lower The lower bound of value (inclusive). + * \param upper The upper bound of value (inclusive). + * \return The Call node indicates lower and upper bound of input expression. + * This intrinsic will be removed before codegen. + */ +TVM_DLL Expr assert_bound(Expr value, Expr lower, Expr upper); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline Expr OpName(Expr x) { \ diff --git a/include/tvm/ir.h b/include/tvm/ir.h index c55a4695de4d..00f6debcbb70 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1613,6 +1613,16 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; */ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; +/*! + * \brief tvm intrinsic for passing bound information of the variables. + * It simply represents the value, while it helps BoundAnalyzer + * understand the upper and lower bound of the value. + * Expr tvm_assert_bound(Expr value, Expr lower_bound, Expr upper_bound) { + * return value; + * } + */ +constexpr const char* tvm_assert_bound = "tvm_assert_bound"; + } // namespace intrinsic /*! diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 6e1fed5a8542..1202fc2271ae 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -529,6 +529,13 @@ LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func); */ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); +/*! + * \brief Remove intrinsic function calls if possible. + * \param f The function to be processed. + * \return Transformed function. + */ +LoweredFunc RemoveIntrin(LoweredFunc f); + /*! * \brief Lower custom datatypes. * diff --git a/python/tvm/api.py b/python/tvm/api.py index ef121bc880b2..bea37bdd3d0c 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -257,6 +257,7 @@ def placeholder(shape, dtype=None, name="placeholder"): The created tensor """ shape = (shape,) if isinstance(shape, _expr.Expr) else shape + shape = tuple(assert_bound(size, 0, None) for size in shape) dtype = float32 if dtype is None else dtype return _api_internal._Placeholder( shape, dtype, name) @@ -296,6 +297,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): shape = (shape,) if isinstance(shape, _expr.Expr) else shape # for python3 shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) + shape = tuple(assert_bound(size, 0, None) for size in shape) ndim = len(shape) code = fcompute.__code__ @@ -1047,6 +1049,27 @@ def floormod(a, b): return _make._OpFloorMod(a, b) +def assert_bound(value, lower=None, upper=None): + """Pass bound information of value. + + Parameters + ---------- + value : Expr + The input expression. + lower : Expr + The lower bound of value (inclusive). Default +inf + upper : Expr + The upper bound of value (inclusive). Default -inf + + Returns + ------- + res : Expr + Call node indicates lower and upper bound of input expression. + This intrinsic will be removed before codegen. + """ + return _make._OpAssertBound(value, lower, upper) + + _init_api("tvm.api") #pylint: disable=unnecessary-lambda diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index f96e28323595..438d33403807 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -292,9 +292,16 @@ def get_binds(args, compact=False, binds=None): binds = {} if binds is None else binds.copy() cfg = current_build_config() arg_list = [] + + def is_var(idx): + if isinstance(idx, expr.Var) or \ + (isinstance(idx, expr.Call) and idx.name == "tvm_assert_bound"): + return True + return False + for x in args: if isinstance(x, tensor.Tensor): - any_dim = any(isinstance(i, expr.Var) for i in x.shape) + any_dim = any(is_var(i) for i in x.shape) buffer_type = "auto_broadcast" if any_dim and not compact else "" if x not in binds: buf = api.decl_buffer(x.shape, @@ -499,7 +506,9 @@ def _build_for_device(flist, target, target_host): fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice] fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost] fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] + fdevice = [ir_pass.RemoveIntrin(x) for x in fdevice] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] + fhost = [ir_pass.RemoveIntrin(x) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost] mdev = codegen.build_module(fdevice, str(target)) if fdevice else None diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 414822068e07..012da179b36e 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -159,3 +159,14 @@ def max_num_threads(func_id, args): _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint") res = _tgt.current_target(args[0].value).max_num_threads return _api.convert(res) + +def tvm_assert_bound(func_id, args): + n = args.__len__() + _internal_assert(func_id == "tvm_assert_bound", "This function cannot be directly invoked!") + _internal_assert(n >= 1, "At least 1 argument should be provided.") + _internal_assert(n <= 3, "Accept at most 3 arguments.") + if n == 1: + return _make._OpAssertBound(args[0], None, None) + elif n == 2: + return _make._OpAssertBound(args[0], args[1], None) + return _make._OpAssertBound(*args) diff --git a/python/tvm/hybrid/preprocessor.py b/python/tvm/hybrid/preprocessor.py index 1a9de4e3f801..035e8a40f245 100644 --- a/python/tvm/hybrid/preprocessor.py +++ b/python/tvm/hybrid/preprocessor.py @@ -63,7 +63,7 @@ def visit_Call(self, node): _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \ ['range', 'max', 'min', 'len'] + \ list(self.symbols.keys()), \ - "Function call id not in intrinsics' list") + "Function call id " + func_id + " not in intrinsics' list") for elem in node.args: self.visit(elem) diff --git a/python/tvm/hybrid/runtime.py b/python/tvm/hybrid/runtime.py index aa00b4b80251..427ddf0779e7 100644 --- a/python/tvm/hybrid/runtime.py +++ b/python/tvm/hybrid/runtime.py @@ -110,36 +110,59 @@ def max_num_threads(allow_none=True): return target.current_target(allow_none).max_num_threads +def tvm_assert_bound(value, lower=None, upper=None): #pylint: disable=unused-argument + """ + Provide lower bound and upper bound for the value. + For now we simply return the value + + Parameters + ---------- + value: Expr + The bounded value + lower: Expr + lower bound (inclusive) + upper: Expr + upper bound (inclusive) + + Returns + ------- + res: Expr + same as value + """ + return value + + HYBRID_GLOBALS = { - 'unroll' : range, - 'vectorize' : range, - 'parallel' : range, - 'const_range' : range, - 'bind' : bind, - 'allocate' : allocate, - 'output_tensor' : allocate, - 'sqrt' : numpy.sqrt, - 'rsqrt' : rsqrt, - 'log' : numpy.log, - 'tanh' : numpy.tanh, - 'power' : numpy.power, - 'exp' : numpy.exp, - 'sigmoid' : sigmoid, - 'popcount' : popcount, - 'likely' : lambda cond: cond, - 'uint8' : numpy.uint8, - 'uint16' : numpy.uint16, - 'uint32' : numpy.uint32, - 'uint64' : numpy.uint64, - 'int8' : numpy.int8, - 'int16' : numpy.int16, - 'int32' : numpy.int32, - 'int64' : numpy.int64, - 'float16' : numpy.float16, - 'float32' : numpy.float32, - 'float64' : numpy.float64, - 'ceil_div' : lambda a, b: (a + b - 1) // b, - 'max_num_threads': max_num_threads + 'unroll' : range, + 'vectorize' : range, + 'parallel' : range, + 'const_range' : range, + 'bind' : bind, + 'allocate' : allocate, + 'output_tensor' : allocate, + 'sqrt' : numpy.sqrt, + 'rsqrt' : rsqrt, + 'log' : numpy.log, + 'tanh' : numpy.tanh, + 'power' : numpy.power, + 'exp' : numpy.exp, + 'sigmoid' : sigmoid, + 'popcount' : popcount, + 'likely' : lambda cond: cond, + 'uint8' : numpy.uint8, + 'uint16' : numpy.uint16, + 'uint32' : numpy.uint32, + 'uint64' : numpy.uint64, + 'int8' : numpy.int8, + 'int16' : numpy.int16, + 'int32' : numpy.int32, + 'int64' : numpy.int64, + 'float16' : numpy.float16, + 'float32' : numpy.float32, + 'float64' : numpy.float64, + 'ceil_div' : lambda a, b: (a + b - 1) // b, + 'max_num_threads' : max_num_threads, + 'tvm_assert_bound' : tvm_assert_bound } diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 03f37b171782..56dbbca7337f 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -193,7 +193,6 @@ TVM_REGISTER_API("make.Allocate") } \ }) - REGISTER_MAKE_BINARY_OP(_OpAdd, operator+); REGISTER_MAKE_BINARY_OP(_OpSub, operator-); REGISTER_MAKE_BINARY_OP(_OpMul, operator*); @@ -225,6 +224,10 @@ TVM_REGISTER_API("make._OpIfThenElse") .set_body_typed([] (Expr cond, Expr true_value, Expr false_value) { return if_then_else(cond, true_value, false_value); }); +TVM_REGISTER_API("make._OpAssertBound") +.set_body_typed([] (Expr value, Expr lower, Expr upper) { + return assert_bound(value, lower, upper); +}); } // namespace ir } // namespace tvm diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 339b25a51894..3616ff617dcc 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -165,6 +165,7 @@ REGISTER_PASS(LowerThreadAllreduce); REGISTER_PASS(LowerWarpMemory); REGISTER_PASS(RemapThreadAxis); REGISTER_PASS(LowerIntrin); +REGISTER_PASS(RemoveIntrin); REGISTER_PASS(LowerCustomDatatypes); REGISTER_PASS(LowerTVMBuiltin); REGISTER_PASS(CombineContextCall); diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 0b84be291f71..daa9c281203a 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -68,6 +68,41 @@ std::vector GetPath(Expr target, Expr expr) { return v.path_; } +class BoundRemover : public IRMutator { + public: + Expr Remove(Expr e) { + remove_bounded_ = true; + return IRMutator::Mutate(ir::Simplify(e)); + } + + Expr Reset(Expr e) { + remove_bounded_ = false; + return IRMutator::Mutate(e); + } + + Expr Mutate_(const Call* op, const Expr& e) final { + if (op->is_intrinsic(intrinsic::tvm_assert_bound) && remove_bounded_) { + Expr value = op->args[0]; + const Variable* var = value.as(); + CHECK(var) << "Invalid value in " << e << ". It should have been simplified."; + bounded_var_map_[var] = GetRef(op); + return value; + } + return IRMutator::Mutate_(op, e); + } + + Expr Mutate_(const Variable* op, const Expr& e) final { + if (!remove_bounded_ && bounded_var_map_.count(op)) { + return bounded_var_map_[op]; + } + return e; + } + + private: + bool remove_bounded_ = false; + std::unordered_map bounded_var_map_; +}; + enum CompareOp {kGreater, kLess, kEqual}; // a visitor to deduce the bound of a variable from a expression @@ -84,7 +119,7 @@ class BoundDeducer: public IRVisitor { void Visit(const ObjectRef& e) final { if (!success_) return; - if (e.get() == path_[iter_++]) { + if (iter_ < path_.size() && e.get() == path_[iter_++]) { IRVisitor::Visit(e); } else { success_ = false; @@ -295,6 +330,18 @@ void BoundDeducer::Transform() { void BoundDeducer::Deduce() { Init(); if (!success_) return; + + // Any variable appears in both expr and result, + // they should not be eagerly simplified according to its bound + // e.g., i + n/4 >= n + // => i >= n - n/4 + // If we eagerly simplified the left side given assert_bound(n, 0, +inf) + // we would get i + 0 >= n => i >= n, which is obviously incorrect. + // Thus we remove assert_bound here and reset later. + BoundRemover bound_remover; + expr_ = bound_remover.Remove(expr_); + result_ = bound_remover.Remove(result_); + Relax(); if (!success_) return; // get the path @@ -306,6 +353,9 @@ void BoundDeducer::Deduce() { expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); Visit(expr_); + + expr_ = bound_remover.Reset(expr_); + result_ = bound_remover.Reset(result_); } void BoundDeducer::Relax() { diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 16e489a9c818..364a753b54fa 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -277,6 +277,11 @@ class ConstIntBoundAnalyzer::Impl : return VisitRightShift(op); } else if (op->is_intrinsic(Call::bitwise_and)) { return VisitBitwiseAnd(op); + } else if (op->is_intrinsic(intrinsic::tvm_assert_bound)) { + Expr value = op->args[0]; + Entry lower = VisitExpr(op->args[1]); + Entry upper = VisitExpr(op->args[2]); + return MakeBound(lower.min_value, upper.max_value); } else { return Everything(op->dtype); } diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 79b39748426d..d41b8b049556 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -416,19 +416,19 @@ class IntervalSetEvaluator : } IntervalSet VisitExpr_(const Div* op) final { - return VisitBinaryExpr_(op); + return VisitDivExpr_(op); } IntervalSet VisitExpr_(const Mod* op) final { - return VisitBinaryExpr_(op); + return VisitDivExpr_(op); } IntervalSet VisitExpr_(const FloorDiv* op) final { - return VisitBinaryExpr_(op); + return VisitDivExpr_(op); } IntervalSet VisitExpr_(const FloorMod* op) final { - return VisitBinaryExpr_(op); + return VisitDivExpr_(op); } IntervalSet VisitExpr_(const Min* op) final { @@ -505,6 +505,24 @@ class IntervalSetEvaluator : return Union(analyzer_, false_set, true_set); } + IntervalSet VisitExpr_(const Call* op) final { + if (op->is_intrinsic(intrinsic::tvm_assert_bound)) { + Expr expr = GetRef(op); + Expr value = op->args[0]; + Expr lb = op->args[1]; + Expr ub = op->args[2]; + // keep the assert_bound intrinsic in the interval, + // e.g., interval of assert_bound(n, 0, n) is [0, assert_bound(n, 0, n)] + // this makes sure variable n NEVER escape the assert_bound CallNode and appear standalone, + // it simplifies the rewrite simplification rules, + // e.g., no need to write things like TVM_TRY_REWRITE((x + y) - assert_bound(x, b1, b2), y) + lb = lb.same_as(value) ? expr : lb; + ub = ub.same_as(value) ? expr : ub; + return IntervalSet(lb, ub); + } + return VisitExprDefault_(op); + } + IntervalSet VisitExprDefault_(const Object* op) final { DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); return IntervalSet::Everything(); @@ -517,6 +535,18 @@ class IntervalSetEvaluator : return set->min_value.same_as(value) && set->max_value.same_as(value); } + bool BoundedBySelf(const Expr& op) const { + if (const Call* call = op.as()) { + if (call->is_intrinsic(intrinsic::tvm_assert_bound)) { + Expr value = call->args[0]; + Expr lb = call->args[1]; + Expr ub = call->args[2]; + return lb.same_as(value) || ub.same_as(value); + } + } + return false; + } + template inline IntervalSet VisitBinaryExpr_(const T* op) { IntervalSet a = this->Eval(op->a); @@ -527,6 +557,22 @@ class IntervalSetEvaluator : return Combine(analyzer_, a, b); } + template + inline IntervalSet VisitDivExpr_(const T* op) { + IntervalSet a = this->Eval(op->a); + IntervalSet b = this->Eval(op->b); + if ((MatchPoint(a, op->a) && (MatchPoint(b, op->b) || BoundedBySelf(op->b))) + || (BoundedBySelf(op->a) && BoundedBySelf(op->b))) { + // e.g., + // div(10, 5) evaluates to 2 + // div(10, assert_bound(n, 0, n)) to itself + // div(assert_bound(m, 0, m), assert_bound(n, 0, n)) to itself + return IntervalSet::SinglePoint(GetRef(op)); + } + // e.g., div(assert_bound(m, 0, m), 2) goes here + return Combine(analyzer_, a, b); + } + // recursive depth int recur_depth_{0}; // analyzer diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 235306cc7bf8..3910d36ead32 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -1716,6 +1716,13 @@ Mutate_(const Call* op, const Expr& self) { // the operator overload will eagerly constant fold. return op->args[0] & op->args[1]; } + } else if (op->is_intrinsic(intrinsic::tvm_assert_bound)) { + Expr value = this->Mutate(op->args[0]); + if (const Call* v = value.as()) { + if (v->is_intrinsic(intrinsic::tvm_assert_bound)) { + return value; + } + } } if (op->is_intrinsic(Call::likely)) { for (const auto& constraint : literal_constraints_) { diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 3ea2cb77d316..dd45d622262d 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -510,6 +510,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fdevice.size(); ++i) { auto func = fdevice[i]; func = ir::LowerIntrin(func, target->target_name); + func = ir::RemoveIntrin(func); fdevice.Set(i, func); } @@ -531,6 +532,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = ir::LowerIntrin(func, target_host->target_name); + func = ir::RemoveIntrin(func); func = ir::LowerDeviceStorageAccessInfo(func); func = ir::CombineContextCall(func); fhost.Set(i, func); diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 1166e7eef976..efc175233180 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include // Centralized header for constant folders. #include "../arithmetic/const_fold.h" @@ -265,7 +266,6 @@ Expr operator%(Expr a, Expr b) { return truncmod(a, b); } -// TODO(tqchen): switch to floordiv Expr indexdiv(Expr a, Expr b) { return floordiv(a, b); } @@ -626,4 +626,19 @@ Expr trunc(Expr x) { return ir::Call::make(x.dtype(), "trunc", {x}, ir::Call::PureIntrinsic); } +Expr assert_bound(Expr value, Expr lower, Expr upper) { + if (!value.as()) { + return value; + } else if (!lower.defined() && !upper.defined()) { + return value; + } + Expr lb = lower.defined() ? lower : value; + Expr ub = upper.defined() ? upper : value; + return ir::Call::make( + value.dtype(), + ir::intrinsic::tvm_assert_bound, + {value, lb, ub}, + ir::Call::PureIntrinsic); +} + } // namespace tvm diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index e4ff9cb457a5..8a3262d1b4d5 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -46,18 +46,30 @@ void BinderAddAssert(Expr cond, } } +const Expr GetVariable(const Expr& expr) { + if (expr.as()) { + return expr; + } else if (const auto* call = expr.as()) { + if (call->is_intrinsic(intrinsic::tvm_assert_bound)) { + return GetVariable(call->args[0]); + } + } + return Expr(); +} + bool ArgBinder::Bind_(const Expr& arg, const Expr& value, const std::string& arg_name, bool with_lets) { CHECK_EQ(arg.dtype(), value.dtype()); - if (const Variable* v = arg.as()) { + Expr arg_as_var = GetVariable(arg); + if (const Variable* v = arg_as_var.as()) { auto it = def_map_->find(v); if (it == def_map_->end()) { - Var v_arg = Downcast(arg); + Var v_arg = Downcast(arg_as_var); defs_.emplace_back(v_arg); if (with_lets) { - (*def_map_)[v] = arg; + (*def_map_)[v] = arg_as_var; init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0))); } else { (*def_map_)[v] = value; diff --git a/src/pass/remove_intrin.cc b/src/pass/remove_intrin.cc new file mode 100644 index 000000000000..5ffe6c0ecac5 --- /dev/null +++ b/src/pass/remove_intrin.cc @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Remove intrinsic calls when possible. + * \file remove_intrin.cc + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace ir { + +class IntrinRemover : public IRMutator { + public: + Expr Mutate_(const Call* op, const Expr& e) final { + if (op->is_intrinsic(intrinsic::tvm_assert_bound)) { + return op->args[0]; // simply return the value + } + return IRMutator::Mutate_(op, e); + } +}; + +Stmt RemoveIntrinStmt(Stmt stmt) { + return IntrinRemover().Mutate(stmt); +} + +Expr RemoveIntrinExpr(Expr expr) { + return IntrinRemover().Mutate(expr); +} + +LoweredFunc RemoveIntrin(LoweredFunc f) { + auto n = make_object(*f.operator->()); + n->body = RemoveIntrinStmt(n->body); + return LoweredFunc(n); +} + +// Register the api only for test purposes +TVM_REGISTER_API("ir_pass._RemoveIntrinStmt") +.set_body_typed(RemoveIntrinStmt); + +TVM_REGISTER_API("ir_pass._RemoveIntrinExpr") +.set_body_typed(RemoveIntrinExpr); + +} // namespace ir +} // namespace tvm + diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index d4baded91f7c..054191b26bcd 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -238,7 +238,7 @@ Map InferBound(const Schedule& sch) { InferRootBound(stage, ctx, &ret); // bind bound of root iter vars. - for (auto iv : stage->op->root_iter_vars()) { + for (auto iv : stage->op->root_iter_vars()) { auto it = ret.find(iv); if (it != ret.end()) { analyzer.Bind(iv->var, it->second); diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 33e31c766950..055b2b747229 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -213,8 +213,18 @@ def test_complex(a1, a2, coff): test_complex(2, 6, -4) +def test_deduce_assert_bound(): + i = tvm.var('i') + x = tvm.assert_bound(tvm.var('x'), 0) + + res = tvm.arith.DeduceBound(i, i+x < x, {}, {}) + assert str(res.min_value) == "neg_inf" + assert tvm.ir_pass.Simplify(res.max_value).value == -1 + + if __name__ == "__main__": test_check() test_deduce() test_deduce_basic() test_deduce_complex() + test_deduce_assert_bound() diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 99c2942cd470..c38562228a46 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -833,6 +833,15 @@ def test_cast_simplify(): for i in [0, 1, 2, 3]: ck.verify(tvm.expr.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1)) +def test_assert_bound_simplify(): + ck = RewriteChecker() + x = tvm.var("x") + y = tvm.var("y") + ck.verify(tvm.assert_bound(tvm.assert_bound(x, 0), 0), tvm.assert_bound(x, 0)) + ck.verify(tvm.assert_bound(x, 0) + 1 >= 1, tvm.const(True, "bool")) + ck.verify(tvm.assert_bound(x, 0, 10) + 1 <= 11, tvm.const(True, "bool")) + ck.verify(tvm.floordiv(tvm.assert_bound(x, 0, 10), tvm.assert_bound(y, 0)) >= 0, tvm.const(True, "bool")) + if __name__ == "__main__": test_floordiv_index_simplify() test_floormod_index_simplify() @@ -849,3 +858,4 @@ def test_cast_simplify(): test_logical_simplify() test_let_simplify() test_cast_simplify() + test_assert_bound_simplify() diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 45ecf9539337..44532dcf50e8 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -50,6 +50,7 @@ def test_add_pipeline(): fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)] # lower the floordiv(use stackvm rules so it works for all targets) fsplits = [tvm.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits] + fsplits = [tvm.ir_pass.RemoveIntrin(x) for x in fsplits] fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0]) def check_target(device, host="stackvm"): diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 1f101a1e92e8..94556867e3ad 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -22,9 +22,10 @@ @pytest.mark.skip def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): + val = tvm.ir_pass._RemoveIntrinExpr(val) val = tvm.ir_pass.Substitute(val, var_dict) val = tvm.ir_pass.Simplify(val) - assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm)) + assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm)), val return val.value ctx = tvm.context(target, 0) @@ -180,7 +181,7 @@ def fanout(n, a): assert isinstance(ir, tvm.stmt.For) assert ir.loop_var.name == 'i' assert ir.min.value == 0 - assert tvm.ir_pass.Equal(ir.extent, n - 3) + assert tvm.ir_pass.Equal(tvm.ir_pass._RemoveIntrinExpr(ir.extent), n - 3) #Check loopbody ibody = ir.body assert isinstance(ibody, tvm.stmt.AttrStmt) diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index c57f4a1109ec..0ddeec8e309a 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -187,6 +187,18 @@ def test_if_then_else(): raise ValueError('Unknown combinations') +def test_assert_bound(): + for dtype in ["int32", "int64"]: + var = tvm.var("var", dtype=dtype) + out = tvm.assert_bound(var, lower=0) + out = tvm.ir_pass._LowerIntrinStmt( + tvm.stmt.Evaluate( + tvm.floordiv(out, tvm.const(127, dtype)) + ), "c") + out = tvm.ir_pass._RemoveIntrinStmt(out) + assert tvm.ir_pass.Equal(out, tvm.stmt.Evaluate(tvm.truncdiv(var, 127))) + + if __name__ == "__main__": test_const_fold() test_const_fold2() @@ -194,3 +206,4 @@ def test_if_then_else(): test_const_fold4() test_binary_dtype_match() test_if_then_else() + test_assert_bound() diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index 0a653066bff7..38d5ee951a9a 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -17,6 +17,7 @@ import pytest import tvm import pickle as pkl +from util import check_assert_bound def test_schedule_create(): m = tvm.var('m') @@ -164,7 +165,8 @@ def test_rfactor(): # normal schedule s = tvm.create_schedule(B.op) BF = s.rfactor(B, k1) - assert(tuple(BF.shape) == (n, n)) + assert(BF.shape[0] == n) + check_assert_bound(BF.shape[1], n, 0, n) assert(set(BF.op.body[0].axis) == set([k2])) assert(s[B].op.body[0].axis[0].dom.extent == n) assert(len(s[B].all_iter_vars) == 2) @@ -174,7 +176,7 @@ def test_rfactor(): xo, xi = s[B].split(B.op.axis[0], factor=8) BF = s.rfactor(B, ki) assert(BF.shape[0].value == 4) - assert(BF.shape[1] == n) + check_assert_bound(BF.shape[1], n, 0, n) assert(BF.op.body[0].axis[0] == k2) assert(BF.op.body[0].axis[1].var == ko.var) assert(s[B].op.body[0].axis[0].dom.extent.value == 4) @@ -183,7 +185,7 @@ def test_rfactor(): ko, ki = s[B].split(k1, factor=4) xo, xi = s[B].split(B.op.axis[0], factor=8) BF = s.rfactor(B, ki, 1) - assert(n == BF.shape[0]) + check_assert_bound(BF.shape[0], n, 0, n) assert(BF.shape[1].value == 4) assert(BF.op.body[0].axis[0] == k2) assert(BF.op.body[0].axis[1].var == ko.var) @@ -222,7 +224,7 @@ def test_tensor_intrin_scalar_params(): def intrin_func(ins, outs, sp): assert(isinstance(ins[0], tvm.schedule.Buffer)) - assert(ins[0].shape[0] == n) + check_assert_bound(ins[0].shape[0], n, 0, n) assert(sp[0] == v) assert(sp[1] == w) return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1]) @@ -232,7 +234,7 @@ def intrin_func(ins, outs, sp): assert intrin.op == z.op assert intrin.reduce_init is None assert tuple(intrin.inputs) == tuple(z.op.input_tensors) - assert(intrin.buffers[0].shape[0] == n) + check_assert_bound(intrin.buffers[0].shape[0], n, 0, n) assert tuple(intrin.scalar_params) == tuple((v, w)) A = tvm.placeholder((10,10), name='A') diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 44aca3b324bb..02124368e7f1 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -16,6 +16,7 @@ # under the License. import tvm from topi.nn.pooling import pool +from util import check_assert_bound def test_tensor(): m = tvm.var('m') @@ -26,7 +27,10 @@ def test_tensor(): T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k]) print(T) print(T.op.body) - assert(tuple(T.shape) == (m, n, l)) + assert(len(T.shape) == 3) + check_assert_bound(T.shape[0], m, 0, m) + check_assert_bound(T.shape[1], n, 0, n) + check_assert_bound(T.shape[2], l, 0, l) assert(isinstance(A.op, tvm.tensor.PlaceholderOp)) assert(A == A) assert(T.op.output(0) == T) @@ -182,7 +186,9 @@ def test_tensor_scan(): res = tvm.scan(tvm.compute((1, n), lambda _, i: x[0, i]), tvm.compute((m, n), lambda t, i: s[t-1, i] + x[t, i]), s) - assert tuple(res.shape) == (m, n) + assert len(res.shape) == 2 + check_assert_bound(res.shape[0], m, 0, m) + check_assert_bound(res.shape[1], n, 0, n) def test_scan_multi_out(): m = tvm.var("m") diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py index 858b1e8a9153..d2cd39d97818 100644 --- a/tests/python/unittest/test_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_pass_inject_copy_intrin.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +from util import check_assert_bound def test_copy2d(): m = tvm.var('m') @@ -29,10 +30,12 @@ def test_copy2d(): Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) def cb(src, dst, pad_before, pad_after, pad_value): - assert dst.strides[0] == l + check_assert_bound(dst.strides[0], l, 0, l) assert dst.strides[1].value == 1 - assert src.strides[0] == l - assert tuple(src.shape) == (m, l) + check_assert_bound(src.strides[0], l, 0, l) + assert len(src.shape) == 2 + check_assert_bound(src.shape[0], m, 0, m) + check_assert_bound(src.shape[1], l, 0, l) return tvm.make.Evaluate(0) stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index 021709506754..95f2264dbc49 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -65,6 +65,7 @@ def test_basic(): stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) assert('if' not in str(stmt.body.body.body.first)) + assert('if' in str(stmt.body.body.body.rest)) def test_const_loop(): n = 21 diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 9c3d1df17f2b..f49498f27a89 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +from util import check_assert_bound def test_bound1(): m = tvm.var('m') @@ -113,18 +114,20 @@ def test_bound_fusesplit1(): assert isinstance(bounds, tvm.container.Map) idxdiv = tvm.indexdiv assert(tvm.ir_pass.Simplify( - bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l)).value == 0) + tvm.ir_pass._RemoveIntrinExpr(bounds[A1.op.axis[0]].min) - idxdiv(xo * split1, l)).value == 0) expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1) + actual_extent = tvm.ir_pass._RemoveIntrinExpr(bounds[A1.op.axis[0]].extent) for i in range(1, 6): for j in range(1, 6): for k in range(1, 6): vars = tvm.convert({split1: tvm.const(i, "int32"), l: tvm.const(j, "int32"), xo.var: tvm.const(k, "int32")}) - comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value + comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(actual_extent, vars)).value exp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expected_extent, vars)).value assert(comp_ext == exp_ext) - assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0) + l_extent = tvm.ir_pass._RemoveIntrinExpr(bounds[A1.op.axis[1]].extent) + assert(tvm.ir_pass.Simplify(l_extent - l).value == 0) def test_bound_fusesplit2(): m = tvm.var("m") @@ -179,7 +182,10 @@ def test_bound_scan(): s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i]) s_scan = tvm.scan(s_init, s_update, s_state) - assert tuple(s_scan.shape) == (m, n) + assert len(s_scan.shape) == 2 + check_assert_bound(s_scan.shape[0], m, 0, m) + check_assert_bound(s_scan.shape[1], n, 0, n) + s = tvm.create_schedule(s_scan.op) XX = s.cache_read(X, "local", s_update) xo, xi = s[s_update].split(s_update.op.axis[1], factor=4) @@ -247,7 +253,7 @@ def test_bound_group_schedule(): s = s.normalize() bounds = tvm.schedule.InferBound(s) assert bounds[x.op.axis[0]].extent.value == 1 - assert bounds[x.op.axis[1]].extent == n + check_assert_bound(bounds[x.op.axis[1]].extent, n, 0, n) def test_bound_nest_group(): m = tvm.var("m") @@ -267,7 +273,7 @@ def test_bound_nest_group(): assert bounds[x.op.axis[0]].extent.value == 1 assert bounds[x.op.axis[1]].extent.value == 1 assert bounds[x1.op.axis[0]].extent.value == 1 - assert bounds[x1.op.axis[1]].extent == n + check_assert_bound(bounds[x1.op.axis[1]].extent, n, 0, n) def test_bound_nest_thread(): @@ -294,7 +300,7 @@ def test_bound_nest_thread(): bounds = tvm.schedule.InferBound(s) assert(bounds[A1.op.axis[0]].extent.value==1) assert(bounds[A2.op.axis[0]].extent.value==32) - assert(bounds[A3.op.axis[0]].extent == m) + check_assert_bound(bounds[A3.op.axis[0]].extent, m, 0, m) def test_gemm_bound(): nn = 1024 diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 5275aec4db90..42619b102289 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -16,6 +16,7 @@ # under the License. import tvm import numpy as np +from util import check_assert_bound def test_schedule0(): m = tvm.var('m') @@ -67,7 +68,10 @@ def test_schedule_scan(): s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i]) res = tvm.scan(s_init, s_update, s_state) - assert tuple(res.shape) == (m, n) + assert len(res.shape) == 2 + check_assert_bound(res.shape[0], m, 0, m) + check_assert_bound(res.shape[1], n, 0, n) + s = tvm.create_schedule(res.op) s = s.normalize() ir = tvm.lower(s, [s_state], simple_mode=True) diff --git a/tests/python/unittest/util.py b/tests/python/unittest/util.py new file mode 100644 index 000000000000..7734654be396 --- /dev/null +++ b/tests/python/unittest/util.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from topi.util import get_const_int + + +def check_assert_bound(expr, var, lb, ub): + assert isinstance(expr, tvm.expr.Call) + assert expr.name == "tvm_assert_bound" + assert expr.dtype == var.dtype + assert expr.args[0] == var + lower = get_const_int(expr.args[1]) if isinstance(expr.args[1], (tvm.expr.IntImm, tvm.expr.UIntImm)) \ + else expr.args[1] + upper = get_const_int(expr.args[2]) if isinstance(expr.args[2], (tvm.expr.IntImm, tvm.expr.UIntImm)) \ + else expr.args[2] + assert lower == lb + assert upper == ub diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 5af30335a9c5..bfe84983c6d7 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -140,18 +140,18 @@ def conv2d_infer_layout(workload, cfg): def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): """ Get the workload structure. """ if data_layout == 'NCHW': - _, CI, IH, IW = [x.value for x in data.shape] + _, CI, IH, IW = get_const_tuple(data.shape) elif data_layout == 'NHWC': - _, IH, IW, CI = [x.value for x in data.shape] + _, IH, IW, CI = get_const_tuple(data.shape) elif data_layout == 'HWCN': - IH, IW, CI, _ = [x.value for x in data.shape] + IH, IW, CI, _ = get_const_tuple(data.shape) else: raise ValueError("not support this layout {} yet".format(data_layout)) if data_layout == 'NCHW': - CO, CIG, KH, KW = [x.value for x in kernel.shape] + CO, CIG, KH, KW = get_const_tuple(kernel.shape) else: - KH, KW, CIG, CO = [x.value for x in kernel.shape] + KH, KW, CIG, CO = get_const_tuple(kernel.shape) HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) GRPS = CI // CIG diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index e25e85dac05e..fa24f37cc62b 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -143,6 +143,24 @@ def equal_const_int(expr, value): return expr.value == value +def is_var(expr): + """Check whether the input is tvm.expr.Var or tvm_assert_bound intrinsic. + + Parameters + ---------- + expr : tvm.Expr + The input expression. + + Returns + ------- + equal : bool + Whether it is tvm.expr.Var or + tvm_assert_bound intrinsic (which provides the boundary information of a Var). + """ + return isinstance(expr, tvm.expr.Var) \ + or (isinstance(expr, tvm.expr.Call) and expr.name == "tvm_assert_bound") + + def get_const_tuple(in_tuple): """Verifies input tuple is IntImm or Var, returns tuple of int or Var. @@ -158,7 +176,7 @@ def get_const_tuple(in_tuple): """ ret = [] for elem in in_tuple: - if isinstance(elem, tvm.expr.Var): + if is_var(elem): ret.append(elem) elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)): elem = tvm.ir_pass.Simplify(elem) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 0e284da17ee6..59fc7389260d 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -30,7 +30,7 @@ conv2d_infer_layout, _get_workload as _get_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.pad import pad -from ..util import get_const_tuple +from ..util import get_const_tuple, is_var from . import conv2d_avx_1x1, conv2d_avx_common @@ -43,7 +43,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth """ static_data_shape = [] for dim in get_const_tuple(data.shape): - if isinstance(dim, tvm.expr.Var): + if is_var(dim): static_data_shape.append(1) else: static_data_shape.append(dim) diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py index b7a3d6d5a330..dcd7a657887e 100644 --- a/topi/python/topi/x86/dense.py +++ b/topi/python/topi/x86/dense.py @@ -24,7 +24,7 @@ from .util import get_fp32_len from .. import generic, tag, nn -from ..util import traverse_inline, get_const_tuple +from ..util import traverse_inline, get_const_tuple, is_var @autotvm.register_topi_compute(nn.dense, "cpu", "direct") def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None): @@ -40,7 +40,7 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None): # Always use dense_nopack for dynamic input. # This is a temporary for CV models. # TODO(kevinthesun): use kernel dispatcher instead. - if isinstance(M, tvm.expr.Var): + if is_var(M): return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype) # For small batch sizes, don't pack weight into cache-friendly layout @@ -59,9 +59,9 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None): M, K = get_const_tuple(data.shape) # batch, in_dim N, _ = get_const_tuple(weight.shape) # out_dim # create tuning space - cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=3) - cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=3) - cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2) + cfg.define_split("tile_y", 32 if is_var(M) else M, num_outputs=3) + cfg.define_split("tile_x", 32 if is_var(N) else N, num_outputs=3) + cfg.define_split("tile_k", 32 if is_var(K) else K, num_outputs=2) if cfg.is_fallback: _default_dense_pack_config(cfg, M, N, K) @@ -93,9 +93,9 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None): M, K = get_const_tuple(data.shape) N, _ = get_const_tuple(weight.shape) # create tuning space - cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2) - cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2) - cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2) + cfg.define_split("tile_y", 32 if is_var(M) else M, num_outputs=2) + cfg.define_split("tile_x", 32 if is_var(N) else N, num_outputs=2) + cfg.define_split("tile_k", 32 if is_var(K) else K, num_outputs=2) if cfg.is_fallback: _default_dense_nopack_config(cfg, M, N, K) @@ -218,11 +218,11 @@ def _schedule_dense_nopack_template(cfg, s, C): def _default_dense_pack_config(cfg, M, N, K): # Generate default schedule for dynamic shape. - if isinstance(M, tvm.expr.Var): + if is_var(M): M = 16 - if isinstance(N, tvm.expr.Var): + if is_var(N): N = 16 - if isinstance(K, tvm.expr.Var): + if is_var(K): K = 16 vec_width = get_fp32_len() @@ -255,11 +255,11 @@ def _default_dense_pack_config(cfg, M, N, K): def _default_dense_nopack_config(cfg, M, N, K): # Generate default schedule for dynamic shape. - if isinstance(M, tvm.expr.Var): + if is_var(M): M = 16 - if isinstance(N, tvm.expr.Var): + if is_var(N): N = 16 - if isinstance(K, tvm.expr.Var): + if is_var(K): K = 16 vec_width = get_fp32_len()