From 2a4357383a536370d337b36ff59bc8a7f298fb80 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 3 Dec 2019 17:38:39 -0800 Subject: [PATCH 01/24] add AssertLowerBound --- include/tvm/ir.h | 16 ++++++++++++++++ include/tvm/ir_functor_ext.h | 2 ++ include/tvm/ir_mutator.h | 1 + include/tvm/ir_visitor.h | 1 + python/tvm/api.py | 1 + python/tvm/expr.py | 17 +++++++++++++++++ src/api/api_ir.cc | 1 + src/lang/attr_functor.h | 2 ++ src/lang/ir.cc | 23 +++++++++++++++++++++++ src/pass/ir_mutator.cc | 6 ++++++ src/pass/ir_visitor.cc | 5 +++++ src/pass/lower_intrin.cc | 1 + 12 files changed, 76 insertions(+) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index aca45f46c0b3..997d265511bb 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -469,6 +469,22 @@ class Let : public ExprNode { TVM_DECLARE_NODE_TYPE_INFO(Let, ExprNode); }; +class AssertLowerBound : public ExprNode { + public: + Expr value; + Expr bound; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("bound", &bound); + } + + TVM_DLL static Expr make(Expr value, Expr bound); + + static constexpr const char* _type_key = "AssertLowerBound"; + TVM_DECLARE_NODE_TYPE_INFO(AssertLowerBound, ExprNode); +}; + // Call node, represent a function call or a multi-dimensional array load. // // TODO(tvm-team): diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 04ce7934ff2f..f6db42a876b4 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -164,6 +164,7 @@ class ExprFunctor { virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const AssertLowerBound* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); @@ -206,6 +207,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(UIntImm); IR_EXPR_FUNCTOR_DISPATCH(FloatImm); IR_EXPR_FUNCTOR_DISPATCH(StringImm); + IR_EXPR_FUNCTOR_DISPATCH(AssertLowerBound); return vtable; } }; diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 5460ae0f4ba9..57c772bd6738 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -121,6 +121,7 @@ class TVM_DLL IRMutator { virtual Expr Mutate_(const FloatImm* op, const Expr& e); virtual Expr Mutate_(const StringImm* op, const Expr& e); virtual Expr Mutate_(const Shuffle* op, const Expr& e); + virtual Expr Mutate_(const AssertLowerBound* op, const Expr& e); }; /*! diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index b85cf233a42f..95b6564ae7f4 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -144,6 +144,7 @@ class TVM_DLL IRVisitor { virtual void Visit_(const UIntImm* op); virtual void Visit_(const FloatImm* op); virtual void Visit_(const StringImm* op); + virtual void Visit_(const AssertLowerBound* op); }; /*! diff --git a/python/tvm/api.py b/python/tvm/api.py index f0261be37e41..76e63e858d84 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -257,6 +257,7 @@ def placeholder(shape, dtype=None, name="placeholder"): The created tensor """ shape = (shape,) if isinstance(shape, _expr.Expr) else shape + shape = tuple(_make.AssertLowerBound(size, 0) for size in shape) dtype = float32 if dtype is None else dtype return _api_internal._Placeholder( shape, dtype, name) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 733f57a68c56..9c833dd61d4a 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -860,3 +860,20 @@ class Let(Expr): def __init__(self, var, value, body): self.__init_handle_by_constructor__( _make.Let, var, value, body) + + +@register_node +class AssertLowerBound(Expr): + """AssertLowerBound node. + + Parameters + ---------- + value : Expr + The value in to be asserted. + + bound : Expr + The lower bound of the value. + """ + def __init__(self, value, bound): + self.__init_handle_by_constructor__( + _make.AssertLowerBound, value, bound) diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 9312c5532302..4a3151584fbe 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -152,6 +152,7 @@ REGISTER_MAKE(Broadcast); REGISTER_MAKE(Shuffle); REGISTER_MAKE(Let); REGISTER_MAKE(LetStmt); +REGISTER_MAKE(AssertLowerBound); REGISTER_MAKE(AssertStmt); REGISTER_MAKE(ProducerConsumer); REGISTER_MAKE(Provide); diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 51b355e81df3..842f9d1b0000 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -103,6 +103,7 @@ class AttrFunctor { virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::AssertLowerBound* op, Args... args) ATTR_FUNCTOR_DEFAULT; private: // initialize the vtable. @@ -138,6 +139,7 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(Cast); ATTR_FUNCTOR_DISPATCH(Call); ATTR_FUNCTOR_DISPATCH(Select); + ATTR_FUNCTOR_DISPATCH(AssertLowerBound); return vtable; } }; diff --git a/src/lang/ir.cc b/src/lang/ir.cc index bb8401dae843..e55fb8bf8efc 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -176,6 +176,18 @@ Expr Let::make(Var var, Expr value, Expr body) { return Expr(node); } +Expr AssertLowerBound::make(Expr value, Expr bound) { + CHECK(value.defined()); + CHECK(bound.defined()); + CHECK_EQ(value.type(), bound.type()); + + NodePtr node = make_node(); + node->type = value.type(); + node->value = std::move(value); + node->bound = std::move(bound); + return Expr(node); +} + const char* Call::vectorizable_intrinsics[] = { "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt", "log", "sin", "cos", "pow", ir::Call::shift_left, ir::Call::shift_right, @@ -835,6 +847,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ")"; }); +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const ObjectRef& node, IRPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "assert_lower_bound("; + p->Print(op->value); + p->stream << ", "; + p->Print(op->bound); + p->stream << ")"; +}); + TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const ObjectRef& node, IRPrinter* p) { auto* op = static_cast(node.get()); @@ -1203,6 +1225,7 @@ TVM_REGISTER_NODE_TYPE(Shuffle); TVM_REGISTER_NODE_TYPE(Prefetch); TVM_REGISTER_NODE_TYPE(Call); TVM_REGISTER_NODE_TYPE(Let); +TVM_REGISTER_NODE_TYPE(AssertLowerBound); TVM_REGISTER_NODE_TYPE(LetStmt); TVM_REGISTER_NODE_TYPE(AssertStmt); TVM_REGISTER_NODE_TYPE(ProducerConsumer); diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index f79a1ab8fe3b..c89a727feea2 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -487,6 +487,11 @@ Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) { } } +Expr IRMutator::Mutate_(const AssertLowerBound *op, const Expr& e) { + Expr value = this->Mutate(op->value); + return value; +} + #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \ return e; \ @@ -529,6 +534,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .DISPATCH_TO_MUTATE_EXPR(UIntImm) .DISPATCH_TO_MUTATE_EXPR(FloatImm) .DISPATCH_TO_MUTATE_EXPR(StringImm) +.DISPATCH_TO_MUTATE_EXPR(AssertLowerBound) .DISPATCH_TO_MUTATE_EXPR(Shuffle); } // namespace ir diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index 204c0f75fe4a..deabe4ae0b4e 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -127,6 +127,10 @@ void IRVisitor::Visit_(const Call *op) { VisitArray(op->args, this); } +void IRVisitor::Visit_(const AssertLowerBound *op) { + this->Visit(op->value); +} + #define DEFINE_BINOP_VISIT_(OP) \ void IRVisitor::Visit_(const OP* op) { \ this->Visit(op->a); \ @@ -277,6 +281,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .DISPATCH_TO_VISIT(Ramp) .DISPATCH_TO_VISIT(Shuffle) .DISPATCH_TO_VISIT(Broadcast) +.DISPATCH_TO_VISIT(AssertLowerBound) .DISPATCH_TO_VISIT(AssertStmt) .DISPATCH_TO_VISIT(ProducerConsumer) .DISPATCH_TO_VISIT(Provide) diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index c2a2fe6f5942..d4818b47d1ea 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -284,6 +284,7 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target) { auto n = make_node(*f.operator->()); n->body = LowerIntrinStmt(n->body, target); +// LOG(INFO) << "after lower intrin " << n->body; return LoweredFunc(n); } From eb63b3a18c2b7926139067dd0a2996c364530ee3 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 9 Dec 2019 09:17:15 -0800 Subject: [PATCH 02/24] add AssertLowerBound to ir_deep_compare.cc --- src/pass/ir_deep_compare.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index cb859d07f07b..3f3b07475f9b 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -246,6 +246,10 @@ class IRDeepCompare : if (CompareArray(op->source, rhs->source) != 0) return; } + void VisitExpr_(const AssertLowerBound *op, const Expr& other) final { + VisitExpr(op->value, other); + } + void VisitExpr_(const IntImm *op, const Expr& other) final { CompareValue(op->value, other.as()->value); } From d1d3047b80ca9c3ae97d3bd4ff574ad4451db27a Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 9 Dec 2019 14:45:30 -0800 Subject: [PATCH 03/24] Make Range for AsserLowerBound; add support in codegen --- include/tvm/ir.h | 4 ++-- python/tvm/api.py | 1 + src/arithmetic/const_int_bound.cc | 4 ++++ src/codegen/codegen_c.cc | 4 ++++ src/codegen/codegen_c.h | 1 + src/codegen/llvm/codegen_llvm.cc | 4 ++++ src/codegen/llvm/codegen_llvm.h | 1 + src/lang/ir.cc | 10 ++++------ src/pass/arg_binder.cc | 16 +++++++++++++--- src/pass/ir_mutator.cc | 2 +- 10 files changed, 35 insertions(+), 12 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 997d265511bb..5b54d6180e5a 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -472,14 +472,14 @@ class Let : public ExprNode { class AssertLowerBound : public ExprNode { public: Expr value; - Expr bound; + int64_t bound; void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); v->Visit("bound", &bound); } - TVM_DLL static Expr make(Expr value, Expr bound); + TVM_DLL static Expr make(Expr value, int64_t bound); static constexpr const char* _type_key = "AssertLowerBound"; TVM_DECLARE_NODE_TYPE_INFO(AssertLowerBound, ExprNode); diff --git a/python/tvm/api.py b/python/tvm/api.py index 76e63e858d84..f0b867a48bd5 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -297,6 +297,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): shape = (shape,) if isinstance(shape, _expr.Expr) else shape # for python3 shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) + shape = tuple(_make.AssertLowerBound(size, 0) for size in shape) ndim = len(shape) code = fcompute.__code__ diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 6e119695a8c8..78d3c287f55b 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -282,6 +282,10 @@ class ConstIntBoundAnalyzer::Impl : } } + Entry VisitExpr_(const AssertLowerBound* op) final { + return MakeBound(op->bound, kPosInf); + } + Entry VisitExpr_(const Variable* op) final { Var v = GetRef(op); auto it = var_map_.find(v); diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index eab542dd3e08..240b03927b9a 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -405,6 +405,10 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N } } +void CodeGenC::VisitExpr_(const AssertLowerBound *op, std::ostream& os) { + PrintExpr(op->value, os); +} + void CodeGenC::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 8701cda1e14c..9b37843d1e1e 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -131,6 +131,7 @@ class CodeGenC : void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AssertLowerBound* op, std::ostream& os) override; // NOLINT(*) // statment void VisitStmt_(const LetStmt* op) override; void VisitStmt_(const Store* op) override; diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 2cff88b0bbf4..97a56f41ea6c 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -1043,6 +1043,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) { return CreateBroadcast(MakeValue(op->value), op->lanes); } +llvm::Value* CodeGenLLVM::VisitExpr_(const AssertLowerBound* op) { + return this->VisitExpr(op->value); +} + void CodeGenLLVM::VisitStmt_(const Store* op) { CHECK(is_one(op->predicate)); Type t = op->value.type(); diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index b7d091b3921b..8be6af32ce51 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -132,6 +132,7 @@ class CodeGenLLVM : llvm::Value* VisitExpr_(const Ramp* op) override; llvm::Value* VisitExpr_(const Shuffle* op) override; llvm::Value* VisitExpr_(const Broadcast* op) override; + llvm::Value* VisitExpr_(const AssertLowerBound* op) override; // stmt void VisitStmt_(const Store* op) override; void VisitStmt_(const For* op) override; diff --git a/src/lang/ir.cc b/src/lang/ir.cc index e55fb8bf8efc..8e4bcfb54ec7 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -176,15 +176,14 @@ Expr Let::make(Var var, Expr value, Expr body) { return Expr(node); } -Expr AssertLowerBound::make(Expr value, Expr bound) { +Expr AssertLowerBound::make(Expr value, int64_t bound) { CHECK(value.defined()); - CHECK(bound.defined()); - CHECK_EQ(value.type(), bound.type()); +// CHECK_EQ(value.type(), bound.type()); NodePtr node = make_node(); node->type = value.type(); node->value = std::move(value); - node->bound = std::move(bound); + node->bound = bound; return Expr(node); } @@ -852,8 +851,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) auto* op = static_cast(node.get()); p->stream << "assert_lower_bound("; p->Print(op->value); - p->stream << ", "; - p->Print(op->bound); + p->stream << ", " << op->bound; p->stream << ")"; }); diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index f892b6b957f8..bec81d9ea890 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -46,18 +46,28 @@ void BinderAddAssert(Expr cond, } } +const Expr GetVariable(const Expr& expr) { + if (expr.as()) { + return expr; + } else if (const auto* v = expr.as()) { + return GetVariable(v->value); + } + return Expr(); +} + bool ArgBinder::Bind_(const Expr& arg, const Expr& value, const std::string& arg_name, bool with_lets) { CHECK_EQ(arg.type(), value.type()); - if (const Variable* v = arg.as()) { + Expr arg_as_var = GetVariable(arg); + if (const Variable* v = arg_as_var.as()) { auto it = def_map_->find(v); if (it == def_map_->end()) { - Var v_arg = Downcast(arg); + Var v_arg = Downcast(arg_as_var); defs_.emplace_back(v_arg); if (with_lets) { - (*def_map_)[v] = arg; + (*def_map_)[v] = arg_as_var; init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0))); } else { (*def_map_)[v] = value; diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index c89a727feea2..c6683fb62822 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -489,7 +489,7 @@ Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) { Expr IRMutator::Mutate_(const AssertLowerBound *op, const Expr& e) { Expr value = this->Mutate(op->value); - return value; + return AssertLowerBound::make(value, 0); } #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ From f3e330593ffd5e746eac97d1c0a31d57099c2cad Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 9 Dec 2019 16:02:48 -0800 Subject: [PATCH 04/24] Expr bound --- include/tvm/ir.h | 4 ++-- src/arithmetic/const_int_bound.cc | 3 ++- src/lang/ir.cc | 10 ++++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 5b54d6180e5a..997d265511bb 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -472,14 +472,14 @@ class Let : public ExprNode { class AssertLowerBound : public ExprNode { public: Expr value; - int64_t bound; + Expr bound; void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); v->Visit("bound", &bound); } - TVM_DLL static Expr make(Expr value, int64_t bound); + TVM_DLL static Expr make(Expr value, Expr bound); static constexpr const char* _type_key = "AssertLowerBound"; TVM_DECLARE_NODE_TYPE_INFO(AssertLowerBound, ExprNode); diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 78d3c287f55b..25c22e7f2fe1 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -283,7 +283,8 @@ class ConstIntBoundAnalyzer::Impl : } Entry VisitExpr_(const AssertLowerBound* op) final { - return MakeBound(op->bound, kPosInf); + Entry bound = VisitExpr(op->bound); + return MakeBound(bound.max_value, kPosInf); } Entry VisitExpr_(const Variable* op) final { diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 8e4bcfb54ec7..e55fb8bf8efc 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -176,14 +176,15 @@ Expr Let::make(Var var, Expr value, Expr body) { return Expr(node); } -Expr AssertLowerBound::make(Expr value, int64_t bound) { +Expr AssertLowerBound::make(Expr value, Expr bound) { CHECK(value.defined()); -// CHECK_EQ(value.type(), bound.type()); + CHECK(bound.defined()); + CHECK_EQ(value.type(), bound.type()); NodePtr node = make_node(); node->type = value.type(); node->value = std::move(value); - node->bound = bound; + node->bound = std::move(bound); return Expr(node); } @@ -851,7 +852,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) auto* op = static_cast(node.get()); p->stream << "assert_lower_bound("; p->Print(op->value); - p->stream << ", " << op->bound; + p->stream << ", "; + p->Print(op->bound); p->stream << ")"; }); From d8014593cde196e2cbc921c0874e589581070476 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 11 Dec 2019 15:48:16 -0800 Subject: [PATCH 05/24] use intrinsic assert_bound --- docs/api/python/dev.rst | 1 + docs/api/python/tvm.rst | 2 + include/tvm/expr_operator.h | 10 ++++ include/tvm/ir.h | 26 ++++----- include/tvm/ir_functor_ext.h | 2 - include/tvm/ir_mutator.h | 1 - include/tvm/ir_pass.h | 7 +++ include/tvm/ir_visitor.h | 1 - python/tvm/api.py | 25 ++++++++- python/tvm/build_module.py | 1 + python/tvm/expr.py | 17 ------ src/api/api_ir.cc | 6 ++- src/api/api_pass.cc | 1 + src/arithmetic/const_int_bound.cc | 10 ++-- src/codegen/codegen_c.cc | 4 -- src/codegen/codegen_c.h | 1 - src/codegen/llvm/codegen_llvm.cc | 4 -- src/codegen/llvm/codegen_llvm.h | 1 - src/lang/attr_functor.h | 2 - src/lang/expr_operator.cc | 14 +++++ src/lang/ir.cc | 23 -------- src/pass/arg_binder.cc | 6 ++- src/pass/ir_deep_compare.cc | 4 -- src/pass/ir_mutator.cc | 6 --- src/pass/ir_visitor.cc | 5 -- src/pass/lower_intrin.cc | 1 - src/pass/remove_intrin.cc | 60 +++++++++++++++++++++ tests/python/unittest/test_lang_operator.py | 13 +++++ 28 files changed, 155 insertions(+), 99 deletions(-) create mode 100644 src/pass/remove_intrin.cc diff --git a/docs/api/python/dev.rst b/docs/api/python/dev.rst index 7bb938ca7517..2c6bec4b2e1b 100644 --- a/docs/api/python/dev.rst +++ b/docs/api/python/dev.rst @@ -73,6 +73,7 @@ tvm.ir_pass tvm.ir_pass.SplitPipeline tvm.ir_pass.LowerThreadAllreduce tvm.ir_pass.LowerIntrin + tvm.ir_pass.RemoveIntrin tvm.ir_pass.LowerTVMBuiltin tvm.ir_pass.NarrowChannelAccess diff --git a/docs/api/python/tvm.rst b/docs/api/python/tvm.rst index b517195db9e4..2636a7b2b945 100644 --- a/docs/api/python/tvm.rst +++ b/docs/api/python/tvm.rst @@ -45,6 +45,7 @@ The user facing API for computation declaration. tvm.min tvm.max tvm.tag_scope + tvm.assert_bound .. autofunction:: tvm.load_json .. autofunction:: tvm.save_json @@ -70,3 +71,4 @@ The user facing API for computation declaration. .. autofunction:: tvm.min .. autofunction:: tvm.max .. autofunction:: tvm.tag_scope +.. autofunction:: tvm.assert_bound diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 625ee8e49286..380a974b1ca0 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -583,6 +583,16 @@ TVM_DLL Expr nearbyint(Expr x); */ TVM_DLL Expr trunc(Expr x); +/*! + * \brief Pass bound information of value. + * \param value The input expression. + * \param lower The lower bound of value (inclusive). + * \param upper The upper bound of value (inclusive). + * \return The Call node indicates lower and upper bound of input expression. + * This intrinsic will be removed before codegen. + */ +TVM_DLL Expr assert_bound(Expr value, Expr lower, Expr upper); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline Expr OpName(Expr x) { \ diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 997d265511bb..a6d1539f6f8e 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -469,22 +469,6 @@ class Let : public ExprNode { TVM_DECLARE_NODE_TYPE_INFO(Let, ExprNode); }; -class AssertLowerBound : public ExprNode { - public: - Expr value; - Expr bound; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("value", &value); - v->Visit("bound", &bound); - } - - TVM_DLL static Expr make(Expr value, Expr bound); - - static constexpr const char* _type_key = "AssertLowerBound"; - TVM_DECLARE_NODE_TYPE_INFO(AssertLowerBound, ExprNode); -}; - // Call node, represent a function call or a multi-dimensional array load. // // TODO(tvm-team): @@ -1629,6 +1613,16 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; */ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; +/*! + * \brief tvm intrinsic for passing bound information of the variables. + * It simply represents the value, while it helps BoundAnalyzer + * understand the upper and lower bound of the value. + * Expr tvm_assert_bound(Expr value, Expr lower_bound, Expr upper_bound) { + * return value; + * } + */ +constexpr const char* tvm_assert_bound = "tvm_assert_bound"; + } // namespace intrinsic /*! diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index f6db42a876b4..04ce7934ff2f 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -164,7 +164,6 @@ class ExprFunctor { virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const AssertLowerBound* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); @@ -207,7 +206,6 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(UIntImm); IR_EXPR_FUNCTOR_DISPATCH(FloatImm); IR_EXPR_FUNCTOR_DISPATCH(StringImm); - IR_EXPR_FUNCTOR_DISPATCH(AssertLowerBound); return vtable; } }; diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 57c772bd6738..5460ae0f4ba9 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -121,7 +121,6 @@ class TVM_DLL IRMutator { virtual Expr Mutate_(const FloatImm* op, const Expr& e); virtual Expr Mutate_(const StringImm* op, const Expr& e); virtual Expr Mutate_(const Shuffle* op, const Expr& e); - virtual Expr Mutate_(const AssertLowerBound* op, const Expr& e); }; /*! diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 5c5c4bb2f452..ff4f8c641992 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -544,6 +544,13 @@ LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func); */ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); +/*! + * \brief Remove intrinsic function calls if possible. + * \param f The function to be processed. + * \return Transformed function. + */ +LoweredFunc RemoveIntrin(LoweredFunc f); + /*! * \brief Lower custom datatypes. * diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index 95b6564ae7f4..b85cf233a42f 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -144,7 +144,6 @@ class TVM_DLL IRVisitor { virtual void Visit_(const UIntImm* op); virtual void Visit_(const FloatImm* op); virtual void Visit_(const StringImm* op); - virtual void Visit_(const AssertLowerBound* op); }; /*! diff --git a/python/tvm/api.py b/python/tvm/api.py index f0b867a48bd5..32ad392a24f7 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -257,7 +257,7 @@ def placeholder(shape, dtype=None, name="placeholder"): The created tensor """ shape = (shape,) if isinstance(shape, _expr.Expr) else shape - shape = tuple(_make.AssertLowerBound(size, 0) for size in shape) + shape = tuple(assert_bound(size, 0, None) for size in shape) dtype = float32 if dtype is None else dtype return _api_internal._Placeholder( shape, dtype, name) @@ -297,7 +297,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): shape = (shape,) if isinstance(shape, _expr.Expr) else shape # for python3 shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) - shape = tuple(_make.AssertLowerBound(size, 0) for size in shape) + shape = tuple(assert_bound(size, 0, None) for size in shape) ndim = len(shape) code = fcompute.__code__ @@ -1049,6 +1049,27 @@ def floormod(a, b): return _make._OpFloorMod(a, b) +def assert_bound(value, lower=None, upper=None): + """Pass bound information of value. + + Parameters + ---------- + value : Expr + The input expression. + lower : Expr + The lower bound of value (inclusive). Default +inf + upper : Expr + The upper bound of value (inclusive). Default -inf + + Returns + ------- + res : Expr + Call node indicates lower and upper bound of input expression. + This intrinsic will be removed before codegen. + """ + return _make._OpAssertBound(value, lower, upper) + + _init_api("tvm.api") #pylint: disable=unnecessary-lambda diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index f96e28323595..02569c949f7a 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -500,6 +500,7 @@ def _build_for_device(flist, target, target_host): fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost] fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] + fhost = [ir_pass.RemoveIntrin(x, target_host.target_name) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost] mdev = codegen.build_module(fdevice, str(target)) if fdevice else None diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 9c833dd61d4a..733f57a68c56 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -860,20 +860,3 @@ class Let(Expr): def __init__(self, var, value, body): self.__init_handle_by_constructor__( _make.Let, var, value, body) - - -@register_node -class AssertLowerBound(Expr): - """AssertLowerBound node. - - Parameters - ---------- - value : Expr - The value in to be asserted. - - bound : Expr - The lower bound of the value. - """ - def __init__(self, value, bound): - self.__init_handle_by_constructor__( - _make.AssertLowerBound, value, bound) diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 4a3151584fbe..6ed69e0e5a08 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -152,7 +152,6 @@ REGISTER_MAKE(Broadcast); REGISTER_MAKE(Shuffle); REGISTER_MAKE(Let); REGISTER_MAKE(LetStmt); -REGISTER_MAKE(AssertLowerBound); REGISTER_MAKE(AssertStmt); REGISTER_MAKE(ProducerConsumer); REGISTER_MAKE(Provide); @@ -194,7 +193,6 @@ TVM_REGISTER_API("make.Allocate") } \ }) - REGISTER_MAKE_BINARY_OP(_OpAdd, operator+); REGISTER_MAKE_BINARY_OP(_OpSub, operator-); REGISTER_MAKE_BINARY_OP(_OpMul, operator*); @@ -226,6 +224,10 @@ TVM_REGISTER_API("make._OpIfThenElse") .set_body_typed([] (Expr cond, Expr true_value, Expr false_value) { return if_then_else(cond, true_value, false_value); }); +TVM_REGISTER_API("make._OpAssertBound") +.set_body_typed([] (Expr value, Expr lower, Expr upper) { + return assert_bound(value, lower, upper); +}); } // namespace ir } // namespace tvm diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index bcafe0904ed2..6b892b8ffebf 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -166,6 +166,7 @@ REGISTER_PASS(LowerThreadAllreduce); REGISTER_PASS(LowerWarpMemory); REGISTER_PASS(RemapThreadAxis); REGISTER_PASS(LowerIntrin); +REGISTER_PASS(RemoveIntrin); REGISTER_PASS(LowerCustomDatatypes); REGISTER_PASS(LowerTVMBuiltin); REGISTER_PASS(CombineContextCall); diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 25c22e7f2fe1..d22851393564 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -277,16 +277,16 @@ class ConstIntBoundAnalyzer::Impl : return VisitRightShift(op); } else if (op->is_intrinsic(Call::bitwise_and)) { return VisitBitwiseAnd(op); + } else if (op->is_intrinsic(intrinsic::tvm_assert_bound)) { + Expr value = op->args[0]; + Entry lower = VisitExpr(op->args[1]); + Entry upper = VisitExpr(op->args[2]); + return MakeBound(lower.min_value, upper.max_value); } else { return Everything(op->type); } } - Entry VisitExpr_(const AssertLowerBound* op) final { - Entry bound = VisitExpr(op->bound); - return MakeBound(bound.max_value, kPosInf); - } - Entry VisitExpr_(const Variable* op) final { Var v = GetRef(op); auto it = var_map_.find(v); diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 240b03927b9a..eab542dd3e08 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -405,10 +405,6 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N } } -void CodeGenC::VisitExpr_(const AssertLowerBound *op, std::ostream& os) { - PrintExpr(op->value, os); -} - void CodeGenC::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 9b37843d1e1e..8701cda1e14c 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -131,7 +131,6 @@ class CodeGenC : void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AssertLowerBound* op, std::ostream& os) override; // NOLINT(*) // statment void VisitStmt_(const LetStmt* op) override; void VisitStmt_(const Store* op) override; diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 97a56f41ea6c..2cff88b0bbf4 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -1043,10 +1043,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) { return CreateBroadcast(MakeValue(op->value), op->lanes); } -llvm::Value* CodeGenLLVM::VisitExpr_(const AssertLowerBound* op) { - return this->VisitExpr(op->value); -} - void CodeGenLLVM::VisitStmt_(const Store* op) { CHECK(is_one(op->predicate)); Type t = op->value.type(); diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 8be6af32ce51..b7d091b3921b 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -132,7 +132,6 @@ class CodeGenLLVM : llvm::Value* VisitExpr_(const Ramp* op) override; llvm::Value* VisitExpr_(const Shuffle* op) override; llvm::Value* VisitExpr_(const Broadcast* op) override; - llvm::Value* VisitExpr_(const AssertLowerBound* op) override; // stmt void VisitStmt_(const Store* op) override; void VisitStmt_(const For* op) override; diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 842f9d1b0000..51b355e81df3 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -103,7 +103,6 @@ class AttrFunctor { virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::AssertLowerBound* op, Args... args) ATTR_FUNCTOR_DEFAULT; private: // initialize the vtable. @@ -139,7 +138,6 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(Cast); ATTR_FUNCTOR_DISPATCH(Call); ATTR_FUNCTOR_DISPATCH(Select); - ATTR_FUNCTOR_DISPATCH(AssertLowerBound); return vtable; } }; diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 220d4378cc97..8feb13843429 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include // Centralized header for constant folders. #include "../arithmetic/const_fold.h" @@ -567,4 +568,17 @@ Expr trunc(Expr x) { return ir::Call::make(x.type(), "trunc", {x}, ir::Call::PureIntrinsic); } +Expr assert_bound(Expr value, Expr lower, Expr upper) { + if (is_const(value) || value.as()) { + return value; + } + Expr lb = lower.defined() ? lower : IntImm::make(Int(64), arith::ConstIntBound::kNegInf); + Expr ub = upper.defined() ? upper : IntImm::make(Int(64), arith::ConstIntBound::kPosInf); + return ir::Call::make( + value.type(), + ir::intrinsic::tvm_assert_bound, + {value, lb, ub}, + ir::Call::PureIntrinsic); +} + } // namespace tvm diff --git a/src/lang/ir.cc b/src/lang/ir.cc index e55fb8bf8efc..bb8401dae843 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -176,18 +176,6 @@ Expr Let::make(Var var, Expr value, Expr body) { return Expr(node); } -Expr AssertLowerBound::make(Expr value, Expr bound) { - CHECK(value.defined()); - CHECK(bound.defined()); - CHECK_EQ(value.type(), bound.type()); - - NodePtr node = make_node(); - node->type = value.type(); - node->value = std::move(value); - node->bound = std::move(bound); - return Expr(node); -} - const char* Call::vectorizable_intrinsics[] = { "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt", "log", "sin", "cos", "pow", ir::Call::shift_left, ir::Call::shift_right, @@ -847,16 +835,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& node, IRPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "assert_lower_bound("; - p->Print(op->value); - p->stream << ", "; - p->Print(op->bound); - p->stream << ")"; -}); - TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const ObjectRef& node, IRPrinter* p) { auto* op = static_cast(node.get()); @@ -1225,7 +1203,6 @@ TVM_REGISTER_NODE_TYPE(Shuffle); TVM_REGISTER_NODE_TYPE(Prefetch); TVM_REGISTER_NODE_TYPE(Call); TVM_REGISTER_NODE_TYPE(Let); -TVM_REGISTER_NODE_TYPE(AssertLowerBound); TVM_REGISTER_NODE_TYPE(LetStmt); TVM_REGISTER_NODE_TYPE(AssertStmt); TVM_REGISTER_NODE_TYPE(ProducerConsumer); diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index bec81d9ea890..0c081eed1705 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -49,8 +49,10 @@ void BinderAddAssert(Expr cond, const Expr GetVariable(const Expr& expr) { if (expr.as()) { return expr; - } else if (const auto* v = expr.as()) { - return GetVariable(v->value); + } else if (const auto* call = expr.as()) { + if (call->is_intrinsic(intrinsic::tvm_assert_bound)) { + return GetVariable(call->args[0]); + } } return Expr(); } diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index 3f3b07475f9b..cb859d07f07b 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -246,10 +246,6 @@ class IRDeepCompare : if (CompareArray(op->source, rhs->source) != 0) return; } - void VisitExpr_(const AssertLowerBound *op, const Expr& other) final { - VisitExpr(op->value, other); - } - void VisitExpr_(const IntImm *op, const Expr& other) final { CompareValue(op->value, other.as()->value); } diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index c6683fb62822..f79a1ab8fe3b 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -487,11 +487,6 @@ Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) { } } -Expr IRMutator::Mutate_(const AssertLowerBound *op, const Expr& e) { - Expr value = this->Mutate(op->value); - return AssertLowerBound::make(value, 0); -} - #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \ return e; \ @@ -534,7 +529,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .DISPATCH_TO_MUTATE_EXPR(UIntImm) .DISPATCH_TO_MUTATE_EXPR(FloatImm) .DISPATCH_TO_MUTATE_EXPR(StringImm) -.DISPATCH_TO_MUTATE_EXPR(AssertLowerBound) .DISPATCH_TO_MUTATE_EXPR(Shuffle); } // namespace ir diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index deabe4ae0b4e..204c0f75fe4a 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -127,10 +127,6 @@ void IRVisitor::Visit_(const Call *op) { VisitArray(op->args, this); } -void IRVisitor::Visit_(const AssertLowerBound *op) { - this->Visit(op->value); -} - #define DEFINE_BINOP_VISIT_(OP) \ void IRVisitor::Visit_(const OP* op) { \ this->Visit(op->a); \ @@ -281,7 +277,6 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .DISPATCH_TO_VISIT(Ramp) .DISPATCH_TO_VISIT(Shuffle) .DISPATCH_TO_VISIT(Broadcast) -.DISPATCH_TO_VISIT(AssertLowerBound) .DISPATCH_TO_VISIT(AssertStmt) .DISPATCH_TO_VISIT(ProducerConsumer) .DISPATCH_TO_VISIT(Provide) diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index d4818b47d1ea..c2a2fe6f5942 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -284,7 +284,6 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target) { auto n = make_node(*f.operator->()); n->body = LowerIntrinStmt(n->body, target); -// LOG(INFO) << "after lower intrin " << n->body; return LoweredFunc(n); } diff --git a/src/pass/remove_intrin.cc b/src/pass/remove_intrin.cc new file mode 100644 index 000000000000..f7e5fbce5f2d --- /dev/null +++ b/src/pass/remove_intrin.cc @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Remove intrinsic calls when possible. + * \file remove_intrin.cc + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace ir { + +class IntrinRemover : public IRMutator { + public: + Expr Mutate_(const Call* op, const Expr& e) final { + if (op->is_intrinsic(intrinsic::tvm_assert_bound)) { + return op->args[0]; // simply return the value + } + return IRMutator::Mutate_(op, e); + } +}; + +Stmt RemoveIntrinStmt(Stmt stmt) { + return IntrinRemover().Mutate(stmt); +} + +LoweredFunc RemoveIntrin(LoweredFunc f) { + auto n = make_node(*f.operator->()); + n->body = RemoveIntrinStmt(n->body); + LOG(INFO) << "after remove intrin " << n->body; + return LoweredFunc(n); +} + +// Register the api only for test purposes +TVM_REGISTER_API("ir_pass._RemoveIntrinStmt") +.set_body_typed(RemoveIntrinStmt); + +} // namespace ir +} // namespace tvm + diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index c57f4a1109ec..0ddeec8e309a 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -187,6 +187,18 @@ def test_if_then_else(): raise ValueError('Unknown combinations') +def test_assert_bound(): + for dtype in ["int32", "int64"]: + var = tvm.var("var", dtype=dtype) + out = tvm.assert_bound(var, lower=0) + out = tvm.ir_pass._LowerIntrinStmt( + tvm.stmt.Evaluate( + tvm.floordiv(out, tvm.const(127, dtype)) + ), "c") + out = tvm.ir_pass._RemoveIntrinStmt(out) + assert tvm.ir_pass.Equal(out, tvm.stmt.Evaluate(tvm.truncdiv(var, 127))) + + if __name__ == "__main__": test_const_fold() test_const_fold2() @@ -194,3 +206,4 @@ def test_if_then_else(): test_const_fold4() test_binary_dtype_match() test_if_then_else() + test_assert_bound() From 105a0777f8dfe7d3d3e52f495d4681df25e88c0f Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 17 Dec 2019 14:54:46 -0800 Subject: [PATCH 06/24] simplify nested assert_bound, deal with assert_bound in InternalVal --- src/arithmetic/int_set.cc | 14 ++++++++++++++ src/arithmetic/rewrite_simplify.cc | 7 +++++++ src/lang/expr_operator.cc | 9 +++++---- src/schedule/bound.cc | 2 +- 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 9f8effb6c612..167ec9738f4e 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -505,6 +505,20 @@ class IntervalSetEvaluator : return Union(analyzer_, false_set, true_set); } + IntervalSet VisitExpr_(const Call* op) final { + if (op->is_intrinsic(intrinsic::tvm_assert_bound)) { + Expr expr = GetRef(op); + Expr value = op->args[0]; + Expr lb = op->args[1]; + Expr ub = op->args[2]; + // TODO(yizhi): remove following "hack" + lb = lb.same_as(value) ? expr : lb; + ub = ub.same_as(value) ? expr : ub; + return IntervalSet(lb, ub); + } + return VisitExprDefault_(op); + } + IntervalSet VisitExprDefault_(const Node* op) final { DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); return IntervalSet::Everything(); diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index b26f8335055a..1265f0196b5e 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -1715,6 +1715,13 @@ Mutate_(const Call* op, const Expr& self) { // the operator overload will eagerly constant fold. return op->args[0] & op->args[1]; } + } else if (op->is_intrinsic(intrinsic::tvm_assert_bound)) { + Expr value = this->Mutate(op->args[0]); + if (const Call* v = value.as()) { + if (v->is_intrinsic(intrinsic::tvm_assert_bound)) { + return value; + } + } } if (op->is_intrinsic(Call::likely)) { for (const auto& constraint : literal_constraints_) { diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 8feb13843429..4316b56ed64e 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -207,7 +207,6 @@ Expr operator%(Expr a, Expr b) { return truncmod(a, b); } -// TODO(tqchen): switch to floordiv Expr indexdiv(Expr a, Expr b) { return floordiv(a, b); } @@ -569,11 +568,13 @@ Expr trunc(Expr x) { } Expr assert_bound(Expr value, Expr lower, Expr upper) { - if (is_const(value) || value.as()) { + if (!value.as()) { + return value; + } else if (!lower.defined() && !upper.defined()) { return value; } - Expr lb = lower.defined() ? lower : IntImm::make(Int(64), arith::ConstIntBound::kNegInf); - Expr ub = upper.defined() ? upper : IntImm::make(Int(64), arith::ConstIntBound::kPosInf); + Expr lb = lower.defined() ? lower : value; + Expr ub = upper.defined() ? upper : value; return ir::Call::make( value.type(), ir::intrinsic::tvm_assert_bound, diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index e213df5e659d..9d080f9333ce 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -238,7 +238,7 @@ Map InferBound(const Schedule& sch) { InferRootBound(stage, ctx, &ret); // bind bound of root iter vars. - for (auto iv : stage->op->root_iter_vars()) { + for (auto iv : stage->op->root_iter_vars()) { auto it = ret.find(iv); if (it != ret.end()) { analyzer.Bind(iv->var, it->second); From 81386e231c7b13fbe2ed42171afd348b5f00c936 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 17 Dec 2019 17:09:42 -0800 Subject: [PATCH 07/24] add test case for assert_bound rewrite simplify --- src/arithmetic/int_set.cc | 6 +++++- tests/python/unittest/test_arith_rewrite_simplify.py | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 167ec9738f4e..2b47e81a2162 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -511,7 +511,11 @@ class IntervalSetEvaluator : Expr value = op->args[0]; Expr lb = op->args[1]; Expr ub = op->args[2]; - // TODO(yizhi): remove following "hack" + // keep the assert_bound intrinsic in the interval, + // e.g., interval of assert_bound(n, 0, n) is [0, assert_bound(n, 0, n)] + // this makes sure variable n NEVER escape the assert_bound CallNode and appear standalone, + // it simplifies the rewrite simplification rules, + // e.g., no need to write things like TVM_TRY_REWRITE((x + y) - assert_bound(x, b1, b2), y) lb = lb.same_as(value) ? expr : lb; ub = ub.same_as(value) ? expr : ub; return IntervalSet(lb, ub); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 99c2942cd470..c7c0492911d6 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -833,6 +833,11 @@ def test_cast_simplify(): for i in [0, 1, 2, 3]: ck.verify(tvm.expr.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1)) +def test_assert_bound_simplify(): + ck = RewriteChecker() + x = tvm.var("x") + ck.verify(tvm.assert_bound(x, 0) + 1 >= 1, tvm.const(True, "bool")) + if __name__ == "__main__": test_floordiv_index_simplify() test_floormod_index_simplify() @@ -849,3 +854,4 @@ def test_cast_simplify(): test_logical_simplify() test_let_simplify() test_cast_simplify() + test_assert_bound_simplify() From 55bd45e58d1467f7c21056a4b1c78ea842fc6601 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 21 Dec 2019 23:05:08 -0800 Subject: [PATCH 08/24] fix deduce bound accordingly --- src/arithmetic/bound_deducer.cc | 51 +++++++++++++++++++ tests/python/unittest/test_lang_tensor.py | 10 +++- .../unittest/test_pass_inject_copy_intrin.py | 9 ++-- .../unittest/test_pass_loop_partition.py | 1 + .../unittest/test_schedule_bound_inference.py | 13 +++-- tests/python/unittest/util.py | 31 +++++++++++ 6 files changed, 106 insertions(+), 9 deletions(-) create mode 100644 tests/python/unittest/util.py diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 31fedcc72cde..e936b954850a 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -68,6 +68,43 @@ std::vector GetPath(Expr target, Expr expr) { return v.path_; } +class BoundRemover : public IRMutator { + public: + Expr Do(Expr e) { + remove_bounded_ = true; + return Mutate(ir::Simplify(e)); + } + + Expr Undo(Expr e) { + CHECK(remove_bounded_) << "Call Do(expr) first."; + remove_bounded_ = false; + return Mutate(e); + } + + Expr Mutate_(const Call* op, const Expr& e) final { + if (op->is_intrinsic(intrinsic::tvm_assert_bound) && remove_bounded_) { + // TODO: deal with recursive assert_bound + Expr value = op->args[0]; + const Variable* var = value.as(); + CHECK(var) << "Invalid value in " << e << ". It should have been simplified."; + bounded_var_map_[var] = GetRef(op); + return value; + } + return IRMutator::Mutate_(op, e); + } + + Expr Mutate_(const Variable* op, const Expr& e) final { + if (!remove_bounded_ && bounded_var_map_.count(op)) { + return bounded_var_map_[op]; + } + return e; + } + + private: + bool remove_bounded_ = false; + std::unordered_map bounded_var_map_; +}; + enum CompareOp {kGreater, kLess, kEqual}; // a visitor to deduce the bound of a variable from a expression @@ -295,6 +332,16 @@ void BoundDeducer::Transform() { void BoundDeducer::Deduce() { Init(); if (!success_) return; + + // Any variable appears in both expr and result, + // they should not be eagerly simplified according to its bound + // e.g., i + n/4 >= n + // => i >= n - n/4 + // Thus we remove assert_bound here and reset later. + BoundRemover ra, rb; + expr_ = ra.Do(expr_); + result_ = rb.Do(result_); + Relax(); if (!success_) return; // get the path @@ -306,9 +353,13 @@ void BoundDeducer::Deduce() { expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); Visit(expr_); + + expr_ = ra.Undo(expr_); + result_ = rb.Undo(result_); } void BoundDeducer::Relax() { + IntSet a = EvalSet(expr_, relax_map_); IntSet b = EvalSet(result_, relax_map_); if (a.is_everything() || b.is_everything()) { diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 44aca3b324bb..02124368e7f1 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -16,6 +16,7 @@ # under the License. import tvm from topi.nn.pooling import pool +from util import check_assert_bound def test_tensor(): m = tvm.var('m') @@ -26,7 +27,10 @@ def test_tensor(): T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k]) print(T) print(T.op.body) - assert(tuple(T.shape) == (m, n, l)) + assert(len(T.shape) == 3) + check_assert_bound(T.shape[0], m, 0, m) + check_assert_bound(T.shape[1], n, 0, n) + check_assert_bound(T.shape[2], l, 0, l) assert(isinstance(A.op, tvm.tensor.PlaceholderOp)) assert(A == A) assert(T.op.output(0) == T) @@ -182,7 +186,9 @@ def test_tensor_scan(): res = tvm.scan(tvm.compute((1, n), lambda _, i: x[0, i]), tvm.compute((m, n), lambda t, i: s[t-1, i] + x[t, i]), s) - assert tuple(res.shape) == (m, n) + assert len(res.shape) == 2 + check_assert_bound(res.shape[0], m, 0, m) + check_assert_bound(res.shape[1], n, 0, n) def test_scan_multi_out(): m = tvm.var("m") diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py index 858b1e8a9153..d2cd39d97818 100644 --- a/tests/python/unittest/test_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_pass_inject_copy_intrin.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +from util import check_assert_bound def test_copy2d(): m = tvm.var('m') @@ -29,10 +30,12 @@ def test_copy2d(): Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) def cb(src, dst, pad_before, pad_after, pad_value): - assert dst.strides[0] == l + check_assert_bound(dst.strides[0], l, 0, l) assert dst.strides[1].value == 1 - assert src.strides[0] == l - assert tuple(src.shape) == (m, l) + check_assert_bound(src.strides[0], l, 0, l) + assert len(src.shape) == 2 + check_assert_bound(src.shape[0], m, 0, m) + check_assert_bound(src.shape[1], l, 0, l) return tvm.make.Evaluate(0) stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index 021709506754..95f2264dbc49 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -65,6 +65,7 @@ def test_basic(): stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) assert('if' not in str(stmt.body.body.body.first)) + assert('if' in str(stmt.body.body.body.rest)) def test_const_loop(): n = 21 diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 9c3d1df17f2b..e62ca3e3e9cb 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +from util import check_assert_bound def test_bound1(): m = tvm.var('m') @@ -112,6 +113,7 @@ def test_bound_fusesplit1(): bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) idxdiv = tvm.indexdiv + print(tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l))) assert(tvm.ir_pass.Simplify( bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l)).value == 0) @@ -179,7 +181,10 @@ def test_bound_scan(): s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i]) s_scan = tvm.scan(s_init, s_update, s_state) - assert tuple(s_scan.shape) == (m, n) + assert len(s_scan.shape) == 2 + check_assert_bound(s_scan.shape[0], m, 0, m) + check_assert_bound(s_scan.shape[1], n, 0, n) + s = tvm.create_schedule(s_scan.op) XX = s.cache_read(X, "local", s_update) xo, xi = s[s_update].split(s_update.op.axis[1], factor=4) @@ -247,7 +252,7 @@ def test_bound_group_schedule(): s = s.normalize() bounds = tvm.schedule.InferBound(s) assert bounds[x.op.axis[0]].extent.value == 1 - assert bounds[x.op.axis[1]].extent == n + check_assert_bound(bounds[x.op.axis[1]].extent, n, 0, n) def test_bound_nest_group(): m = tvm.var("m") @@ -267,7 +272,7 @@ def test_bound_nest_group(): assert bounds[x.op.axis[0]].extent.value == 1 assert bounds[x.op.axis[1]].extent.value == 1 assert bounds[x1.op.axis[0]].extent.value == 1 - assert bounds[x1.op.axis[1]].extent == n + check_assert_bound(bounds[x1.op.axis[1]].extent, n, 0, n) def test_bound_nest_thread(): @@ -294,7 +299,7 @@ def test_bound_nest_thread(): bounds = tvm.schedule.InferBound(s) assert(bounds[A1.op.axis[0]].extent.value==1) assert(bounds[A2.op.axis[0]].extent.value==32) - assert(bounds[A3.op.axis[0]].extent == m) + check_assert_bound(bounds[A3.op.axis[0]].extent, m, 0, m) def test_gemm_bound(): nn = 1024 diff --git a/tests/python/unittest/util.py b/tests/python/unittest/util.py new file mode 100644 index 000000000000..7734654be396 --- /dev/null +++ b/tests/python/unittest/util.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from topi.util import get_const_int + + +def check_assert_bound(expr, var, lb, ub): + assert isinstance(expr, tvm.expr.Call) + assert expr.name == "tvm_assert_bound" + assert expr.dtype == var.dtype + assert expr.args[0] == var + lower = get_const_int(expr.args[1]) if isinstance(expr.args[1], (tvm.expr.IntImm, tvm.expr.UIntImm)) \ + else expr.args[1] + upper = get_const_int(expr.args[2]) if isinstance(expr.args[2], (tvm.expr.IntImm, tvm.expr.UIntImm)) \ + else expr.args[2] + assert lower == lb + assert upper == ub From c405d7b4b25b6a826f70183a758179473320b0dc Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 23 Dec 2019 08:20:36 -0800 Subject: [PATCH 09/24] fix floordiv IntervalSetEvaluator when b \in [0, n] --- python/tvm/build_module.py | 2 +- src/arithmetic/int_set.cc | 7 ++++++- src/pass/remove_intrin.cc | 7 +++++++ tests/python/unittest/test_lang_schedule.py | 12 +++++++----- .../python/unittest/test_schedule_bound_inference.py | 9 +++++---- 5 files changed, 26 insertions(+), 11 deletions(-) diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 02569c949f7a..00aef417a777 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -500,7 +500,7 @@ def _build_for_device(flist, target, target_host): fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost] fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] - fhost = [ir_pass.RemoveIntrin(x, target_host.target_name) for x in fhost] + fhost = [ir_pass.RemoveIntrin(x) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost] mdev = codegen.build_module(fdevice, str(target)) if fdevice else None diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 2b47e81a2162..61be34df6139 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -424,7 +424,12 @@ class IntervalSetEvaluator : } IntervalSet VisitExpr_(const FloorDiv* op) final { - return VisitBinaryExpr_(op); + IntervalSet a = this->Eval(op->a); + IntervalSet b = this->Eval(op->b); + if (MatchPoint(a, op->a) && (b->min_value.same_as(op->b) || b->max_value.same_as(op->b))) { + return IntervalSet::SinglePoint(GetRef(op)); + } + return Combine(analyzer_, a, b); } IntervalSet VisitExpr_(const FloorMod* op) final { diff --git a/src/pass/remove_intrin.cc b/src/pass/remove_intrin.cc index f7e5fbce5f2d..659e502071e2 100644 --- a/src/pass/remove_intrin.cc +++ b/src/pass/remove_intrin.cc @@ -44,6 +44,10 @@ Stmt RemoveIntrinStmt(Stmt stmt) { return IntrinRemover().Mutate(stmt); } +Expr RemoveIntrinExpr(Expr expr) { + return IntrinRemover().Mutate(expr); +} + LoweredFunc RemoveIntrin(LoweredFunc f) { auto n = make_node(*f.operator->()); n->body = RemoveIntrinStmt(n->body); @@ -55,6 +59,9 @@ LoweredFunc RemoveIntrin(LoweredFunc f) { TVM_REGISTER_API("ir_pass._RemoveIntrinStmt") .set_body_typed(RemoveIntrinStmt); +TVM_REGISTER_API("ir_pass._RemoveIntrinExpr") +.set_body_typed(RemoveIntrinExpr); + } // namespace ir } // namespace tvm diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index 0a653066bff7..38d5ee951a9a 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -17,6 +17,7 @@ import pytest import tvm import pickle as pkl +from util import check_assert_bound def test_schedule_create(): m = tvm.var('m') @@ -164,7 +165,8 @@ def test_rfactor(): # normal schedule s = tvm.create_schedule(B.op) BF = s.rfactor(B, k1) - assert(tuple(BF.shape) == (n, n)) + assert(BF.shape[0] == n) + check_assert_bound(BF.shape[1], n, 0, n) assert(set(BF.op.body[0].axis) == set([k2])) assert(s[B].op.body[0].axis[0].dom.extent == n) assert(len(s[B].all_iter_vars) == 2) @@ -174,7 +176,7 @@ def test_rfactor(): xo, xi = s[B].split(B.op.axis[0], factor=8) BF = s.rfactor(B, ki) assert(BF.shape[0].value == 4) - assert(BF.shape[1] == n) + check_assert_bound(BF.shape[1], n, 0, n) assert(BF.op.body[0].axis[0] == k2) assert(BF.op.body[0].axis[1].var == ko.var) assert(s[B].op.body[0].axis[0].dom.extent.value == 4) @@ -183,7 +185,7 @@ def test_rfactor(): ko, ki = s[B].split(k1, factor=4) xo, xi = s[B].split(B.op.axis[0], factor=8) BF = s.rfactor(B, ki, 1) - assert(n == BF.shape[0]) + check_assert_bound(BF.shape[0], n, 0, n) assert(BF.shape[1].value == 4) assert(BF.op.body[0].axis[0] == k2) assert(BF.op.body[0].axis[1].var == ko.var) @@ -222,7 +224,7 @@ def test_tensor_intrin_scalar_params(): def intrin_func(ins, outs, sp): assert(isinstance(ins[0], tvm.schedule.Buffer)) - assert(ins[0].shape[0] == n) + check_assert_bound(ins[0].shape[0], n, 0, n) assert(sp[0] == v) assert(sp[1] == w) return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1]) @@ -232,7 +234,7 @@ def intrin_func(ins, outs, sp): assert intrin.op == z.op assert intrin.reduce_init is None assert tuple(intrin.inputs) == tuple(z.op.input_tensors) - assert(intrin.buffers[0].shape[0] == n) + check_assert_bound(intrin.buffers[0].shape[0], n, 0, n) assert tuple(intrin.scalar_params) == tuple((v, w)) A = tvm.placeholder((10,10), name='A') diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index e62ca3e3e9cb..f49498f27a89 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -113,20 +113,21 @@ def test_bound_fusesplit1(): bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) idxdiv = tvm.indexdiv - print(tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l))) assert(tvm.ir_pass.Simplify( - bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l)).value == 0) + tvm.ir_pass._RemoveIntrinExpr(bounds[A1.op.axis[0]].min) - idxdiv(xo * split1, l)).value == 0) expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1) + actual_extent = tvm.ir_pass._RemoveIntrinExpr(bounds[A1.op.axis[0]].extent) for i in range(1, 6): for j in range(1, 6): for k in range(1, 6): vars = tvm.convert({split1: tvm.const(i, "int32"), l: tvm.const(j, "int32"), xo.var: tvm.const(k, "int32")}) - comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value + comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(actual_extent, vars)).value exp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expected_extent, vars)).value assert(comp_ext == exp_ext) - assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0) + l_extent = tvm.ir_pass._RemoveIntrinExpr(bounds[A1.op.axis[1]].extent) + assert(tvm.ir_pass.Simplify(l_extent - l).value == 0) def test_bound_fusesplit2(): m = tvm.var("m") From e3639b9926770e4665fff36d6bb50fbe917c5b68 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 23 Dec 2019 09:02:45 -0800 Subject: [PATCH 10/24] fix lint --- src/arithmetic/bound_deducer.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 9a59239ce79a..8633ce66bd17 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -83,7 +83,6 @@ class BoundRemover : public IRMutator { Expr Mutate_(const Call* op, const Expr& e) final { if (op->is_intrinsic(intrinsic::tvm_assert_bound) && remove_bounded_) { - // TODO: deal with recursive assert_bound Expr value = op->args[0]; const Variable* var = value.as(); CHECK(var) << "Invalid value in " << e << ". It should have been simplified."; @@ -359,7 +358,6 @@ void BoundDeducer::Deduce() { } void BoundDeducer::Relax() { - IntSet a = EvalSet(expr_, relax_map_); IntSet b = EvalSet(result_, relax_map_); if (a.is_everything() || b.is_everything()) { From 9c94b5d587015a2fba78b9a80dfafd8ec82a19a1 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 23 Dec 2019 12:00:16 -0800 Subject: [PATCH 11/24] fix compile error --- src/lang/expr_operator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 945ff40ca67c..efc175233180 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -635,7 +635,7 @@ Expr assert_bound(Expr value, Expr lower, Expr upper) { Expr lb = lower.defined() ? lower : value; Expr ub = upper.defined() ? upper : value; return ir::Call::make( - value.type(), + value.dtype(), ir::intrinsic::tvm_assert_bound, {value, lb, ub}, ir::Call::PureIntrinsic); From 38fb4d36f594fd714e88c6ab073b08424898b80a Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 26 Dec 2019 20:47:09 -0800 Subject: [PATCH 12/24] add assert_bound in hybrid script --- python/tvm/hybrid/calls.py | 12 +++ python/tvm/hybrid/preprocessor.py | 2 +- python/tvm/hybrid/runtime.py | 81 ++++++++++++------- src/pass/remove_intrin.cc | 1 - tests/python/unittest/test_hybrid_script.py | 5 +- .../unittest/test_schedule_schedule_ops.py | 6 +- 6 files changed, 73 insertions(+), 34 deletions(-) diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 414822068e07..d27c45e7b83d 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -159,3 +159,15 @@ def max_num_threads(func_id, args): _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint") res = _tgt.current_target(args[0].value).max_num_threads return _api.convert(res) + +def tvm_assert_bound(func_id, args): + n = args.__len__() + _internal_assert(func_id == "tvm_assert_bound", "This function cannot be directly invoked!") + _internal_assert(n >= 1, "At least 1 argument should be provided.") + _internal_assert(n <= 3, "Accept at most 3 arguments.") + if n == 1: + return _make._OpAssertBound(args[0], None, None) + elif n == 2: + return _make._OpAssertBound(args[0], args[1], None) + return _make._OpAssertBound(*args) + diff --git a/python/tvm/hybrid/preprocessor.py b/python/tvm/hybrid/preprocessor.py index 1a9de4e3f801..035e8a40f245 100644 --- a/python/tvm/hybrid/preprocessor.py +++ b/python/tvm/hybrid/preprocessor.py @@ -63,7 +63,7 @@ def visit_Call(self, node): _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \ ['range', 'max', 'min', 'len'] + \ list(self.symbols.keys()), \ - "Function call id not in intrinsics' list") + "Function call id " + func_id + " not in intrinsics' list") for elem in node.args: self.visit(elem) diff --git a/python/tvm/hybrid/runtime.py b/python/tvm/hybrid/runtime.py index aa00b4b80251..f02bf2cb5593 100644 --- a/python/tvm/hybrid/runtime.py +++ b/python/tvm/hybrid/runtime.py @@ -110,36 +110,59 @@ def max_num_threads(allow_none=True): return target.current_target(allow_none).max_num_threads +def tvm_assert_bound(value, lb=None, ub=None): + """ + Provide lower bound and upper bound for the value. + For now we simply return the value + + Parameters + ---------- + value: Expr + The bounded value + lb: Expr + lower bound (inclusive) + ub: Expr + upper bound (inclusive) + + Returns + ------- + res: Expr + same as value + """ + return value + + HYBRID_GLOBALS = { - 'unroll' : range, - 'vectorize' : range, - 'parallel' : range, - 'const_range' : range, - 'bind' : bind, - 'allocate' : allocate, - 'output_tensor' : allocate, - 'sqrt' : numpy.sqrt, - 'rsqrt' : rsqrt, - 'log' : numpy.log, - 'tanh' : numpy.tanh, - 'power' : numpy.power, - 'exp' : numpy.exp, - 'sigmoid' : sigmoid, - 'popcount' : popcount, - 'likely' : lambda cond: cond, - 'uint8' : numpy.uint8, - 'uint16' : numpy.uint16, - 'uint32' : numpy.uint32, - 'uint64' : numpy.uint64, - 'int8' : numpy.int8, - 'int16' : numpy.int16, - 'int32' : numpy.int32, - 'int64' : numpy.int64, - 'float16' : numpy.float16, - 'float32' : numpy.float32, - 'float64' : numpy.float64, - 'ceil_div' : lambda a, b: (a + b - 1) // b, - 'max_num_threads': max_num_threads + 'unroll' : range, + 'vectorize' : range, + 'parallel' : range, + 'const_range' : range, + 'bind' : bind, + 'allocate' : allocate, + 'output_tensor' : allocate, + 'sqrt' : numpy.sqrt, + 'rsqrt' : rsqrt, + 'log' : numpy.log, + 'tanh' : numpy.tanh, + 'power' : numpy.power, + 'exp' : numpy.exp, + 'sigmoid' : sigmoid, + 'popcount' : popcount, + 'likely' : lambda cond: cond, + 'uint8' : numpy.uint8, + 'uint16' : numpy.uint16, + 'uint32' : numpy.uint32, + 'uint64' : numpy.uint64, + 'int8' : numpy.int8, + 'int16' : numpy.int16, + 'int32' : numpy.int32, + 'int64' : numpy.int64, + 'float16' : numpy.float16, + 'float32' : numpy.float32, + 'float64' : numpy.float64, + 'ceil_div' : lambda a, b: (a + b - 1) // b, + 'max_num_threads' : max_num_threads, + 'tvm_assert_bound' : tvm_assert_bound } diff --git a/src/pass/remove_intrin.cc b/src/pass/remove_intrin.cc index 659e502071e2..1553df72f212 100644 --- a/src/pass/remove_intrin.cc +++ b/src/pass/remove_intrin.cc @@ -51,7 +51,6 @@ Expr RemoveIntrinExpr(Expr expr) { LoweredFunc RemoveIntrin(LoweredFunc f) { auto n = make_node(*f.operator->()); n->body = RemoveIntrinStmt(n->body); - LOG(INFO) << "after remove intrin " << n->body; return LoweredFunc(n); } diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 1f101a1e92e8..94556867e3ad 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -22,9 +22,10 @@ @pytest.mark.skip def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): + val = tvm.ir_pass._RemoveIntrinExpr(val) val = tvm.ir_pass.Substitute(val, var_dict) val = tvm.ir_pass.Simplify(val) - assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm)) + assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm)), val return val.value ctx = tvm.context(target, 0) @@ -180,7 +181,7 @@ def fanout(n, a): assert isinstance(ir, tvm.stmt.For) assert ir.loop_var.name == 'i' assert ir.min.value == 0 - assert tvm.ir_pass.Equal(ir.extent, n - 3) + assert tvm.ir_pass.Equal(tvm.ir_pass._RemoveIntrinExpr(ir.extent), n - 3) #Check loopbody ibody = ir.body assert isinstance(ibody, tvm.stmt.AttrStmt) diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 5275aec4db90..42619b102289 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -16,6 +16,7 @@ # under the License. import tvm import numpy as np +from util import check_assert_bound def test_schedule0(): m = tvm.var('m') @@ -67,7 +68,10 @@ def test_schedule_scan(): s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i]) res = tvm.scan(s_init, s_update, s_state) - assert tuple(res.shape) == (m, n) + assert len(res.shape) == 2 + check_assert_bound(res.shape[0], m, 0, m) + check_assert_bound(res.shape[1], n, 0, n) + s = tvm.create_schedule(res.op) s = s.normalize() ir = tvm.lower(s, [s_state], simple_mode=True) From d078c839c203e6603b6ec22a771f8851a7e63632 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 26 Dec 2019 20:56:49 -0800 Subject: [PATCH 13/24] fix lint --- python/tvm/hybrid/calls.py | 1 - python/tvm/hybrid/runtime.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index d27c45e7b83d..012da179b36e 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -170,4 +170,3 @@ def tvm_assert_bound(func_id, args): elif n == 2: return _make._OpAssertBound(args[0], args[1], None) return _make._OpAssertBound(*args) - diff --git a/python/tvm/hybrid/runtime.py b/python/tvm/hybrid/runtime.py index f02bf2cb5593..427ddf0779e7 100644 --- a/python/tvm/hybrid/runtime.py +++ b/python/tvm/hybrid/runtime.py @@ -110,7 +110,7 @@ def max_num_threads(allow_none=True): return target.current_target(allow_none).max_num_threads -def tvm_assert_bound(value, lb=None, ub=None): +def tvm_assert_bound(value, lower=None, upper=None): #pylint: disable=unused-argument """ Provide lower bound and upper bound for the value. For now we simply return the value @@ -119,9 +119,9 @@ def tvm_assert_bound(value, lb=None, ub=None): ---------- value: Expr The bounded value - lb: Expr + lower: Expr lower bound (inclusive) - ub: Expr + upper: Expr upper bound (inclusive) Returns From ac26b2792b572a420985d47d2dc54848bfcf37a9 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 27 Dec 2019 00:18:33 -0800 Subject: [PATCH 14/24] fix auto buffer bind for assert_bound --- python/tvm/build_module.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 00aef417a777..be3ab2d16f4d 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -292,9 +292,16 @@ def get_binds(args, compact=False, binds=None): binds = {} if binds is None else binds.copy() cfg = current_build_config() arg_list = [] + + def is_var(idx): + if isinstance(idx, expr.Var) or \ + (isinstance(idx, expr.Call) and idx.name == "tvm_assert_bound"): + return True + return False + for x in args: if isinstance(x, tensor.Tensor): - any_dim = any(isinstance(i, expr.Var) for i in x.shape) + any_dim = any(is_var(i) for i in x.shape) buffer_type = "auto_broadcast" if any_dim and not compact else "" if x not in binds: buf = api.decl_buffer(x.shape, From 35d71b8827704687aa8eb4570da16f1d13ee6603 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 27 Dec 2019 11:07:23 -0800 Subject: [PATCH 15/24] debug ci --- src/pass/remove_intrin.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pass/remove_intrin.cc b/src/pass/remove_intrin.cc index 1553df72f212..2272a6d75484 100644 --- a/src/pass/remove_intrin.cc +++ b/src/pass/remove_intrin.cc @@ -50,6 +50,7 @@ Expr RemoveIntrinExpr(Expr expr) { LoweredFunc RemoveIntrin(LoweredFunc f) { auto n = make_node(*f.operator->()); + LOG(INFO) << n->body; n->body = RemoveIntrinStmt(n->body); return LoweredFunc(n); } From 82adeab5ef028187a61d8fc86fd82d5c86907cd0 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 27 Dec 2019 12:13:57 -0800 Subject: [PATCH 16/24] revoke --- src/pass/remove_intrin.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pass/remove_intrin.cc b/src/pass/remove_intrin.cc index 2272a6d75484..1553df72f212 100644 --- a/src/pass/remove_intrin.cc +++ b/src/pass/remove_intrin.cc @@ -50,7 +50,6 @@ Expr RemoveIntrinExpr(Expr expr) { LoweredFunc RemoveIntrin(LoweredFunc f) { auto n = make_node(*f.operator->()); - LOG(INFO) << n->body; n->body = RemoveIntrinStmt(n->body); return LoweredFunc(n); } From cfecf9d0ab0cc7146c50b7ca9e43dee75605cecc Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 27 Dec 2019 13:25:28 -0800 Subject: [PATCH 17/24] retrigger --- src/pass/remove_intrin.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pass/remove_intrin.cc b/src/pass/remove_intrin.cc index 1553df72f212..2272a6d75484 100644 --- a/src/pass/remove_intrin.cc +++ b/src/pass/remove_intrin.cc @@ -50,6 +50,7 @@ Expr RemoveIntrinExpr(Expr expr) { LoweredFunc RemoveIntrin(LoweredFunc f) { auto n = make_node(*f.operator->()); + LOG(INFO) << n->body; n->body = RemoveIntrinStmt(n->body); return LoweredFunc(n); } From 10bc34988e6962d12257155438701e1f00cb759f Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 27 Dec 2019 20:44:39 -0800 Subject: [PATCH 18/24] fix out of bound in path_ visit --- src/arithmetic/bound_deducer.cc | 6 +++--- src/pass/remove_intrin.cc | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 8633ce66bd17..fe361191a9be 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -72,13 +72,13 @@ class BoundRemover : public IRMutator { public: Expr Do(Expr e) { remove_bounded_ = true; - return Mutate(ir::Simplify(e)); + return IRMutator::Mutate(ir::Simplify(e)); } Expr Undo(Expr e) { CHECK(remove_bounded_) << "Call Do(expr) first."; remove_bounded_ = false; - return Mutate(e); + return IRMutator::Mutate(e); } Expr Mutate_(const Call* op, const Expr& e) final { @@ -120,7 +120,7 @@ class BoundDeducer: public IRVisitor { void Visit(const NodeRef& e) final { if (!success_) return; - if (e.get() == path_[iter_++]) { + if (iter_ < path_.size() && e.get() == path_[iter_++]) { IRVisitor::Visit(e); } else { success_ = false; diff --git a/src/pass/remove_intrin.cc b/src/pass/remove_intrin.cc index 2272a6d75484..1553df72f212 100644 --- a/src/pass/remove_intrin.cc +++ b/src/pass/remove_intrin.cc @@ -50,7 +50,6 @@ Expr RemoveIntrinExpr(Expr expr) { LoweredFunc RemoveIntrin(LoweredFunc f) { auto n = make_node(*f.operator->()); - LOG(INFO) << n->body; n->body = RemoveIntrinStmt(n->body); return LoweredFunc(n); } From 68ae2243877e864ec3941468aa78fcb28e6c1c77 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 27 Dec 2019 23:40:37 -0800 Subject: [PATCH 19/24] fix test_any.py --- python/tvm/build_module.py | 1 + src/arithmetic/int_set.cc | 11 ++++++++++- src/codegen/build_module.cc | 2 ++ topi/python/topi/nn/conv2d.py | 10 +++++----- topi/python/topi/util.py | 20 +++++++++++++++++++- topi/python/topi/x86/conv2d.py | 4 ++-- topi/python/topi/x86/dense.py | 28 ++++++++++++++-------------- 7 files changed, 53 insertions(+), 23 deletions(-) diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index be3ab2d16f4d..438d33403807 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -506,6 +506,7 @@ def _build_for_device(flist, target, target_host): fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice] fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost] fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] + fdevice = [ir_pass.RemoveIntrin(x) for x in fdevice] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] fhost = [ir_pass.RemoveIntrin(x) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost] diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 17f2e4dfea99..fb54aac7720a 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -427,13 +427,22 @@ class IntervalSetEvaluator : IntervalSet a = this->Eval(op->a); IntervalSet b = this->Eval(op->b); if (MatchPoint(a, op->a) && (b->min_value.same_as(op->b) || b->max_value.same_as(op->b))) { + // e.g., floordiv(10, [0, n]) + // if using VisitBinaryExpr_ it will be inferred as IntervalSet::Everything() return IntervalSet::SinglePoint(GetRef(op)); } return Combine(analyzer_, a, b); } IntervalSet VisitExpr_(const FloorMod* op) final { - return VisitBinaryExpr_(op); + IntervalSet a = this->Eval(op->a); + IntervalSet b = this->Eval(op->b); + if (MatchPoint(a, op->a) && (b->min_value.same_as(op->b) || b->max_value.same_as(op->b))) { + // e.g., floormod(10, [0, n]) + // if using VisitBinaryExpr_ it will be inferred as IntervalSet::Everything() + return IntervalSet::SinglePoint(GetRef(op)); + } + return Combine(analyzer_, a, b); } IntervalSet VisitExpr_(const Min* op) final { diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index ca25731cafef..42576d6701b7 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -510,6 +510,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fdevice.size(); ++i) { auto func = fdevice[i]; func = ir::LowerIntrin(func, target->target_name); + func = ir::RemoveIntrin(func); fdevice.Set(i, func); } @@ -531,6 +532,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = ir::LowerIntrin(func, target_host->target_name); + func = ir::RemoveIntrin(func); func = ir::LowerDeviceStorageAccessInfo(func); func = ir::CombineContextCall(func); fhost.Set(i, func); diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 5af30335a9c5..bfe84983c6d7 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -140,18 +140,18 @@ def conv2d_infer_layout(workload, cfg): def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): """ Get the workload structure. """ if data_layout == 'NCHW': - _, CI, IH, IW = [x.value for x in data.shape] + _, CI, IH, IW = get_const_tuple(data.shape) elif data_layout == 'NHWC': - _, IH, IW, CI = [x.value for x in data.shape] + _, IH, IW, CI = get_const_tuple(data.shape) elif data_layout == 'HWCN': - IH, IW, CI, _ = [x.value for x in data.shape] + IH, IW, CI, _ = get_const_tuple(data.shape) else: raise ValueError("not support this layout {} yet".format(data_layout)) if data_layout == 'NCHW': - CO, CIG, KH, KW = [x.value for x in kernel.shape] + CO, CIG, KH, KW = get_const_tuple(kernel.shape) else: - KH, KW, CIG, CO = [x.value for x in kernel.shape] + KH, KW, CIG, CO = get_const_tuple(kernel.shape) HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) GRPS = CI // CIG diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index e25e85dac05e..208abf28dd5f 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -143,6 +143,24 @@ def equal_const_int(expr, value): return expr.value == value +def is_var(input): + """Check whether the input is tvm.expr.Var or tvm_assert_bound intrinsic. + + Parameters + ---------- + input : tvm.Expr + The input expression. + + Returns + ------- + equal : bool + Whether it is tvm.expr.Var or + tvm_assert_bound intrinsic (which provides the boundary information of a Var). + """ + return isinstance(input, tvm.expr.Var) \ + or (isinstance(input, tvm.expr.Call) and input.name == "tvm_assert_bound") + + def get_const_tuple(in_tuple): """Verifies input tuple is IntImm or Var, returns tuple of int or Var. @@ -158,7 +176,7 @@ def get_const_tuple(in_tuple): """ ret = [] for elem in in_tuple: - if isinstance(elem, tvm.expr.Var): + if is_var(elem): ret.append(elem) elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)): elem = tvm.ir_pass.Simplify(elem) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 0e284da17ee6..59fc7389260d 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -30,7 +30,7 @@ conv2d_infer_layout, _get_workload as _get_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.pad import pad -from ..util import get_const_tuple +from ..util import get_const_tuple, is_var from . import conv2d_avx_1x1, conv2d_avx_common @@ -43,7 +43,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth """ static_data_shape = [] for dim in get_const_tuple(data.shape): - if isinstance(dim, tvm.expr.Var): + if is_var(dim): static_data_shape.append(1) else: static_data_shape.append(dim) diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py index b7a3d6d5a330..dcd7a657887e 100644 --- a/topi/python/topi/x86/dense.py +++ b/topi/python/topi/x86/dense.py @@ -24,7 +24,7 @@ from .util import get_fp32_len from .. import generic, tag, nn -from ..util import traverse_inline, get_const_tuple +from ..util import traverse_inline, get_const_tuple, is_var @autotvm.register_topi_compute(nn.dense, "cpu", "direct") def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None): @@ -40,7 +40,7 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None): # Always use dense_nopack for dynamic input. # This is a temporary for CV models. # TODO(kevinthesun): use kernel dispatcher instead. - if isinstance(M, tvm.expr.Var): + if is_var(M): return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype) # For small batch sizes, don't pack weight into cache-friendly layout @@ -59,9 +59,9 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None): M, K = get_const_tuple(data.shape) # batch, in_dim N, _ = get_const_tuple(weight.shape) # out_dim # create tuning space - cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=3) - cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=3) - cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2) + cfg.define_split("tile_y", 32 if is_var(M) else M, num_outputs=3) + cfg.define_split("tile_x", 32 if is_var(N) else N, num_outputs=3) + cfg.define_split("tile_k", 32 if is_var(K) else K, num_outputs=2) if cfg.is_fallback: _default_dense_pack_config(cfg, M, N, K) @@ -93,9 +93,9 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None): M, K = get_const_tuple(data.shape) N, _ = get_const_tuple(weight.shape) # create tuning space - cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2) - cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2) - cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2) + cfg.define_split("tile_y", 32 if is_var(M) else M, num_outputs=2) + cfg.define_split("tile_x", 32 if is_var(N) else N, num_outputs=2) + cfg.define_split("tile_k", 32 if is_var(K) else K, num_outputs=2) if cfg.is_fallback: _default_dense_nopack_config(cfg, M, N, K) @@ -218,11 +218,11 @@ def _schedule_dense_nopack_template(cfg, s, C): def _default_dense_pack_config(cfg, M, N, K): # Generate default schedule for dynamic shape. - if isinstance(M, tvm.expr.Var): + if is_var(M): M = 16 - if isinstance(N, tvm.expr.Var): + if is_var(N): N = 16 - if isinstance(K, tvm.expr.Var): + if is_var(K): K = 16 vec_width = get_fp32_len() @@ -255,11 +255,11 @@ def _default_dense_pack_config(cfg, M, N, K): def _default_dense_nopack_config(cfg, M, N, K): # Generate default schedule for dynamic shape. - if isinstance(M, tvm.expr.Var): + if is_var(M): M = 16 - if isinstance(N, tvm.expr.Var): + if is_var(N): N = 16 - if isinstance(K, tvm.expr.Var): + if is_var(K): K = 16 vec_width = get_fp32_len() From 8d16054562d076aa28ac789f2c3e0f039ec2b644 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 29 Dec 2019 14:22:36 -0800 Subject: [PATCH 20/24] fix lint --- topi/python/topi/util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 208abf28dd5f..fa24f37cc62b 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -143,12 +143,12 @@ def equal_const_int(expr, value): return expr.value == value -def is_var(input): +def is_var(expr): """Check whether the input is tvm.expr.Var or tvm_assert_bound intrinsic. Parameters ---------- - input : tvm.Expr + expr : tvm.Expr The input expression. Returns @@ -157,8 +157,8 @@ def is_var(input): Whether it is tvm.expr.Var or tvm_assert_bound intrinsic (which provides the boundary information of a Var). """ - return isinstance(input, tvm.expr.Var) \ - or (isinstance(input, tvm.expr.Call) and input.name == "tvm_assert_bound") + return isinstance(expr, tvm.expr.Var) \ + or (isinstance(expr, tvm.expr.Call) and expr.name == "tvm_assert_bound") def get_const_tuple(in_tuple): From 4f31cce1bb0aa3a49f32500fff2822135313db24 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 30 Dec 2019 11:20:14 -0800 Subject: [PATCH 21/24] fix gpu unittest --- tests/python/unittest/test_codegen_device.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 45ecf9539337..44532dcf50e8 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -50,6 +50,7 @@ def test_add_pipeline(): fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)] # lower the floordiv(use stackvm rules so it works for all targets) fsplits = [tvm.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits] + fsplits = [tvm.ir_pass.RemoveIntrin(x) for x in fsplits] fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0]) def check_target(device, host="stackvm"): From 3d15b40198093d3545ebfeb95a0053d2b1e5d1e8 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 31 Dec 2019 16:13:37 -0800 Subject: [PATCH 22/24] polish bound deducer --- src/arithmetic/bound_deducer.cc | 14 +++--- src/arithmetic/int_set.cc | 50 ++++++++++++------- .../unittest/test_arith_rewrite_simplify.py | 1 + 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 1fb02d558319..2e67f81eabc5 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -70,12 +70,12 @@ std::vector GetPath(Expr target, Expr expr) { class BoundRemover : public IRMutator { public: - Expr Do(Expr e) { + Expr Remove(Expr e) { remove_bounded_ = true; return IRMutator::Mutate(ir::Simplify(e)); } - Expr Undo(Expr e) { + Expr Reset(Expr e) { CHECK(remove_bounded_) << "Call Do(expr) first."; remove_bounded_ = false; return IRMutator::Mutate(e); @@ -336,10 +336,12 @@ void BoundDeducer::Deduce() { // they should not be eagerly simplified according to its bound // e.g., i + n/4 >= n // => i >= n - n/4 + // If we eagerly simplified the left side given assert_bound(n, 0, +inf) + // we would get i + 0 >= n => i >= n, which is obviously incorrect. // Thus we remove assert_bound here and reset later. BoundRemover ra, rb; - expr_ = ra.Do(expr_); - result_ = rb.Do(result_); + expr_ = ra.Remove(expr_); + result_ = rb.Remove(result_); Relax(); if (!success_) return; @@ -353,8 +355,8 @@ void BoundDeducer::Deduce() { Visit(expr_); - expr_ = ra.Undo(expr_); - result_ = rb.Undo(result_); + expr_ = ra.Reset(expr_); + result_ = rb.Reset(result_); } void BoundDeducer::Relax() { diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 696a1cb35028..d41b8b049556 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -416,33 +416,19 @@ class IntervalSetEvaluator : } IntervalSet VisitExpr_(const Div* op) final { - return VisitBinaryExpr_(op); + return VisitDivExpr_(op); } IntervalSet VisitExpr_(const Mod* op) final { - return VisitBinaryExpr_(op); + return VisitDivExpr_(op); } IntervalSet VisitExpr_(const FloorDiv* op) final { - IntervalSet a = this->Eval(op->a); - IntervalSet b = this->Eval(op->b); - if (MatchPoint(a, op->a) && (b->min_value.same_as(op->b) || b->max_value.same_as(op->b))) { - // e.g., floordiv(10, [0, n]) - // if using VisitBinaryExpr_ it will be inferred as IntervalSet::Everything() - return IntervalSet::SinglePoint(GetRef(op)); - } - return Combine(analyzer_, a, b); + return VisitDivExpr_(op); } IntervalSet VisitExpr_(const FloorMod* op) final { - IntervalSet a = this->Eval(op->a); - IntervalSet b = this->Eval(op->b); - if (MatchPoint(a, op->a) && (b->min_value.same_as(op->b) || b->max_value.same_as(op->b))) { - // e.g., floormod(10, [0, n]) - // if using VisitBinaryExpr_ it will be inferred as IntervalSet::Everything() - return IntervalSet::SinglePoint(GetRef(op)); - } - return Combine(analyzer_, a, b); + return VisitDivExpr_(op); } IntervalSet VisitExpr_(const Min* op) final { @@ -549,6 +535,18 @@ class IntervalSetEvaluator : return set->min_value.same_as(value) && set->max_value.same_as(value); } + bool BoundedBySelf(const Expr& op) const { + if (const Call* call = op.as()) { + if (call->is_intrinsic(intrinsic::tvm_assert_bound)) { + Expr value = call->args[0]; + Expr lb = call->args[1]; + Expr ub = call->args[2]; + return lb.same_as(value) || ub.same_as(value); + } + } + return false; + } + template inline IntervalSet VisitBinaryExpr_(const T* op) { IntervalSet a = this->Eval(op->a); @@ -559,6 +557,22 @@ class IntervalSetEvaluator : return Combine(analyzer_, a, b); } + template + inline IntervalSet VisitDivExpr_(const T* op) { + IntervalSet a = this->Eval(op->a); + IntervalSet b = this->Eval(op->b); + if ((MatchPoint(a, op->a) && (MatchPoint(b, op->b) || BoundedBySelf(op->b))) + || (BoundedBySelf(op->a) && BoundedBySelf(op->b))) { + // e.g., + // div(10, 5) evaluates to 2 + // div(10, assert_bound(n, 0, n)) to itself + // div(assert_bound(m, 0, m), assert_bound(n, 0, n)) to itself + return IntervalSet::SinglePoint(GetRef(op)); + } + // e.g., div(assert_bound(m, 0, m), 2) goes here + return Combine(analyzer_, a, b); + } + // recursive depth int recur_depth_{0}; // analyzer diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index c7c0492911d6..026dc48eb402 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -837,6 +837,7 @@ def test_assert_bound_simplify(): ck = RewriteChecker() x = tvm.var("x") ck.verify(tvm.assert_bound(x, 0) + 1 >= 1, tvm.const(True, "bool")) + ck.verify(tvm.assert_bound(x, 0, 10) + 1 <= 11, tvm.const(True, "bool")) if __name__ == "__main__": test_floordiv_index_simplify() From 768b95b6746325192d1e758c1f24f4f654ad9128 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 31 Dec 2019 16:36:34 -0800 Subject: [PATCH 23/24] fix build --- src/pass/remove_intrin.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pass/remove_intrin.cc b/src/pass/remove_intrin.cc index 1553df72f212..5ffe6c0ecac5 100644 --- a/src/pass/remove_intrin.cc +++ b/src/pass/remove_intrin.cc @@ -49,7 +49,7 @@ Expr RemoveIntrinExpr(Expr expr) { } LoweredFunc RemoveIntrin(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = RemoveIntrinStmt(n->body); return LoweredFunc(n); } From a905f1f974271fd54ed744470dda48ebc4e660b9 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 1 Jan 2020 13:17:45 -0800 Subject: [PATCH 24/24] fix bound remover --- src/arithmetic/bound_deducer.cc | 11 +++++------ tests/python/unittest/test_arith_deduce_bound.py | 10 ++++++++++ tests/python/unittest/test_arith_rewrite_simplify.py | 3 +++ 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 2e67f81eabc5..daa9c281203a 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -76,7 +76,6 @@ class BoundRemover : public IRMutator { } Expr Reset(Expr e) { - CHECK(remove_bounded_) << "Call Do(expr) first."; remove_bounded_ = false; return IRMutator::Mutate(e); } @@ -339,9 +338,9 @@ void BoundDeducer::Deduce() { // If we eagerly simplified the left side given assert_bound(n, 0, +inf) // we would get i + 0 >= n => i >= n, which is obviously incorrect. // Thus we remove assert_bound here and reset later. - BoundRemover ra, rb; - expr_ = ra.Remove(expr_); - result_ = rb.Remove(result_); + BoundRemover bound_remover; + expr_ = bound_remover.Remove(expr_); + result_ = bound_remover.Remove(result_); Relax(); if (!success_) return; @@ -355,8 +354,8 @@ void BoundDeducer::Deduce() { Visit(expr_); - expr_ = ra.Reset(expr_); - result_ = rb.Reset(result_); + expr_ = bound_remover.Reset(expr_); + result_ = bound_remover.Reset(result_); } void BoundDeducer::Relax() { diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 33e31c766950..055b2b747229 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -213,8 +213,18 @@ def test_complex(a1, a2, coff): test_complex(2, 6, -4) +def test_deduce_assert_bound(): + i = tvm.var('i') + x = tvm.assert_bound(tvm.var('x'), 0) + + res = tvm.arith.DeduceBound(i, i+x < x, {}, {}) + assert str(res.min_value) == "neg_inf" + assert tvm.ir_pass.Simplify(res.max_value).value == -1 + + if __name__ == "__main__": test_check() test_deduce() test_deduce_basic() test_deduce_complex() + test_deduce_assert_bound() diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 026dc48eb402..c38562228a46 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -836,8 +836,11 @@ def test_cast_simplify(): def test_assert_bound_simplify(): ck = RewriteChecker() x = tvm.var("x") + y = tvm.var("y") + ck.verify(tvm.assert_bound(tvm.assert_bound(x, 0), 0), tvm.assert_bound(x, 0)) ck.verify(tvm.assert_bound(x, 0) + 1 >= 1, tvm.const(True, "bool")) ck.verify(tvm.assert_bound(x, 0, 10) + 1 <= 11, tvm.const(True, "bool")) + ck.verify(tvm.floordiv(tvm.assert_bound(x, 0, 10), tvm.assert_bound(y, 0)) >= 0, tvm.const(True, "bool")) if __name__ == "__main__": test_floordiv_index_simplify()