From 2fad3f1237e8462df6f56c16d92c4ee0cb75941f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 6 Feb 2021 16:05:48 +0900 Subject: [PATCH 01/29] add while node --- include/tvm/tir/stmt.h | 47 ++++++++++++++++++++++++++++++++++++++++++ src/tir/ir/stmt.cc | 31 ++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 074bcdd3f533..81712502c9c1 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -861,6 +861,53 @@ class For : public Stmt { TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); }; +/*! + * \brief A While loop + * + * \code + * + * while (condition) + * body + * + * \endcode + */ +class WhileNode : public StmtNode { + public: + /*! \brief The minimum value of iteration. */ + PrimExpr condition; + /*! \brief The body of the for loop. */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("condition", &condition); + v->Visit("body", &body); + v->Visit("span", &span); + } + + bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const { + return equal.DefEqual(condition, other->condition) && equal.DefEqual(body, other->body); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(condition); + hash_reduce.DefHash(body); + } + + static constexpr const char* _type_key = "tir.While"; + TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode); +}; + +/*! + * \brief Managed reference to WhileNode. + * \sa WhileNode + */ +class While : public Stmt { + public: + TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode); +}; + /*! * \brief A prefetch hint for a buffer */ diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index e54be4347c8e..ae58202748cd 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -197,6 +197,37 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); +// While +While::While(PrimExpr condition, Stmt body, Span span) { + ICHECK(condition.defined()); + ICHECK(condition.dtype().is_scalar()); + ICHECK(body.defined()); + + ObjectPtr node = make_object(); + node->condition = std::move(condition); + node->body = std::move(body); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) { + return While(condition, body, span); +}); + +TVM_REGISTER_NODE_TYPE(WhileNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "while(" << op->condition << "){\n"; + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + p->PrintIndent(); + p->stream << "}\n"; + }); + // Store Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) { ICHECK(value.defined()); From 4382dbfed946570d28c12fa3cde4c16ba7ce69a8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 6 Feb 2021 16:20:07 +0900 Subject: [PATCH 02/29] update visitors --- include/tvm/tir/stmt_functor.h | 3 +++ src/tir/ir/stmt_functor.cc | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index e53b02d73e1d..31e2c55be7d1 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -86,6 +86,7 @@ class StmtFunctor { virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -152,6 +153,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; @@ -245,6 +247,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const IfThenElseNode* op) override; Stmt VisitStmt_(const LetStmtNode* op) override; Stmt VisitStmt_(const ForNode* op) override; + Stmt VisitStmt_(const WhileNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index f05dc7116494..639d38db0a81 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -45,6 +45,11 @@ void StmtVisitor::VisitStmt_(const ForNode* op) { this->VisitStmt(op->body); } +void StmtVisitor::VisitStmt_(const WhileNode* op) { + this->VisitExpr(op->condition); + this->VisitStmt(op->body); +} + void StmtVisitor::VisitStmt_(const AllocateNode* op) { VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitStmt(op->body); @@ -283,6 +288,19 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) { } } +Stmt StmtMutator::VisitStmt_(const WhileNode* op) { + PrimExpr condition = this->VisitExpr(op->condition); + Stmt body = this->VisitStmt(op->body); + if (condition.same_as(op->condition) && body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->condition = std::move(condition); + n->body = std::move(body); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { Array extents = Internal::Mutate(this, op->extents); Stmt body = this->VisitStmt(op->body); From 81ddea4f0d467b7f80e02c3ff5637a0ec147c3f4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 6 Feb 2021 20:15:09 +0900 Subject: [PATCH 03/29] binary search lowering works --- include/tvm/tir/stmt_functor.h | 1 + python/tvm/tir/ir_builder.py | 8 ++ python/tvm/tir/stmt.py | 12 +++ src/printer/text_printer.h | 1 + src/printer/tir_text_printer.cc | 7 ++ src/tir/transforms/storage_rewrite.cc | 2 + tests/python/unittest/test_tir_ir_builder.py | 101 ++++++++++++++++++- 7 files changed, 127 insertions(+), 5 deletions(-) diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 31e2c55be7d1..ceebbbb305ce 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -112,6 +112,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode); IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode); IR_STMT_FUNCTOR_DISPATCH(ForNode); + IR_STMT_FUNCTOR_DISPATCH(WhileNode); IR_STMT_FUNCTOR_DISPATCH(AllocateNode); IR_STMT_FUNCTOR_DISPATCH(StoreNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 437e8f6610f4..a674325e25cd 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -263,6 +263,14 @@ def _exit_cb(): return WithScope(loop_var, _exit_cb) + def while_loop(self, condition): + """TODO""" + self._seq_stack.append([]) + def _exit_cb(): + self.emit(_stmt.While(condition, self._pop_seq())) + + return WithScope(None, _exit_cb) + def if_scope(self, cond): """Create an if scope. diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index e4f1ac924a83..f3b734c9e0ef 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -158,6 +158,18 @@ def __init__( span, ) +@tvm._ffi.register_object("tir.While") +class While(Stmt): + """TODO""" + + def __init__(self, condition, body, span=None): + self.__init_handle_by_constructor__( + _ffi_api.While, + condition, + body, + span, + ) + @tvm._ffi.register_object("tir.Store") class Store(Stmt): diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 9a24fe65b4b1..6ec32a9e104c 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -308,6 +308,7 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; Doc VisitStmt_(const ForNode* op) override; + Doc VisitStmt_(const WhileNode* op) override; Doc VisitStmt_(const PrefetchNode* op) override; Doc VisitStmtDefault_(const Object* op) override; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 711af2a8fd08..8d5bba5e5bb0 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -494,6 +494,13 @@ Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const WhileNode* op) { + Doc doc; + doc << "while (" << Print(op->condition) << ")"; + doc << PrintBody(op->body); + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) { Doc doc; doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")"; diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 0b1429ca7efa..1281d09f2ec8 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -192,6 +192,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } + void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); } + void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } // linearized access sequence. diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index b84ee09b9fd9..d43cccd8ca67 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -173,9 +173,100 @@ def check_target(target): check_target("cuda") +def test_binary_search(): + def binary_search(ib, n, i, Aptr, Bptr, Cptr): + lo = ib.allocate("int32", (1,), name="lo", scope="local") + hi = ib.allocate("int32", (1,), name="hi", scope="local") + + lo[0] = 0 + hi[0] = n + v = Bptr[i] + + with ib.while_loop(lo[0] < hi[0]) as _: + mid = lo[0] + tvm.tir.floordiv(hi[0] - lo[0], 2).astype("int32") + with ib.if_scope(Aptr[mid] < v): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid + + Cptr[i] = lo[0] + + def searchsorted_ir_cpu(A, B, C, n): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + with ib.for_range(0, n, name="i", kind="parallel") as i: + binary_search(ib, n, i, Aptr, Bptr, Cptr) + + body = ib.get() + + return body + + def searchsorted_ir_gpu(A, B, C, n): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + max_threads = 32 + ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads)) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.if_scope(tid < n): + binary_search(ib, n, tid, Aptr, Bptr, Cptr) + + body = ib.get() + + return body + + n = 1024 + dtype = "float32" + A = te.placeholder((n,), name="A", dtype=dtype) + B = te.placeholder((n,), name="B", dtype=dtype) + + def check_target(target, ir): + if not tvm.testing.device_enabled(target): + return + + C = te.extern( + A.shape, + [A, B], + lambda ins, outs: ir(ins[0], ins[1], outs[0], n), + name="searchsorted_ir", + dtype="int32", + ) + s = te.create_schedule(C.op) + + with tvm.transform.PassContext(opt_level=3, disabled_pass=["HoistIfThenElse"]): + print(tvm.lower(s, [A, B, C], simple_mode=True)) + return + func = tvm.build(s, [A, B, C], target) + + ctx = tvm.context(target, 0) + a_np = np.random.uniform(size=n).astype(A.dtype) + b_np = np.random.uniform(size=n).astype(B.dtype) + a_np = np.sort(a_np) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) + func(a, b, c) + ref = np.searchsorted(a_np, b_np) + tvm.testing.assert_allclose(c.asnumpy(), ref) + + check_target("llvm", searchsorted_ir_cpu) + # check_target("cuda", searchsorted_ir_gpu) + # check_target("nvptx", searchsorted_ir_gpu) + + if __name__ == "__main__": - test_prefetch() - test_if() - test_for() - test_cpu() - test_gpu() + # test_prefetch() + # test_if() + # test_for() + # test_cpu() + # test_gpu() + test_binary_search() From 322d025a0cab7d6ec72a8c6fef1f4c36ea88668f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 7 Feb 2021 08:51:43 +0900 Subject: [PATCH 04/29] llvm codegen working --- src/target/llvm/codegen_llvm.cc | 14 ++++++++++++++ src/target/llvm/codegen_llvm.h | 1 + 2 files changed, 15 insertions(+) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 1dd76f6b9d51..d5140677d45a 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1328,6 +1328,20 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); } +void CodeGenLLVM::VisitStmt_(const WhileNode* op) { + using llvm::BasicBlock; + BasicBlock* while_cond = BasicBlock::Create(*ctx_, "while_cond", function_); + BasicBlock* while_body = BasicBlock::Create(*ctx_, "while_body", function_); + BasicBlock* while_merge = BasicBlock::Create(*ctx_, "while_merge", function_); + builder_->CreateBr(while_cond); + builder_->SetInsertPoint(while_cond); + builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge); + builder_->SetInsertPoint(while_body); + this->VisitStmt(op->body); + builder_->CreateBr(while_cond); + builder_->SetInsertPoint(while_merge); +} + void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { using llvm::BasicBlock; llvm::Value* cond = MakeValue(op->condition); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 71583708da2c..e56a6de6d914 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -152,6 +152,7 @@ class CodeGenLLVM : public ExprFunctor, // stmt void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; From 45647b1a08266952119d436096cb9ac63b60d290 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 7 Feb 2021 09:20:06 +0900 Subject: [PATCH 05/29] cuda codegen working --- src/target/source/codegen_c.cc | 11 ++++++++++- src/target/source/codegen_c.h | 1 + tests/python/unittest/test_tir_ir_builder.py | 16 +++++++--------- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index af175c7f2208..55db59f8d842 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -728,7 +728,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { ICHECK(is_one(op->predicate)) << "Predicated store is not supported"; arith::PVar base; - if (arith::ramp(base, 1, t.lanes()).Match(op->index)) { std::string value = this->PrintExpr(op->value); this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value); @@ -899,6 +898,16 @@ void CodeGenC::VisitStmt_(const ForNode* op) { stream << "}\n"; } +void CodeGenC::VisitStmt_(const WhileNode* op) { + PrintIndent(); + stream << "while (" << PrintExpr(op->condition) << ") {\n"; + int while_scope = BeginScope(); + PrintStmt(op->body); + this->EndScope(while_scope); + PrintIndent(); + stream << "}\n"; +} + void CodeGenC::VisitStmt_(const IfThenElseNode* op) { std::string cond = PrintExpr(op->condition); PrintIndent(); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index c1b566c064a4..76e6a9bc7197 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -150,6 +150,7 @@ class CodeGenC : public ExprFunctor, void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index d43cccd8ca67..60e58923724c 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -243,8 +243,6 @@ def check_target(target, ir): s = te.create_schedule(C.op) with tvm.transform.PassContext(opt_level=3, disabled_pass=["HoistIfThenElse"]): - print(tvm.lower(s, [A, B, C], simple_mode=True)) - return func = tvm.build(s, [A, B, C], target) ctx = tvm.context(target, 0) @@ -259,14 +257,14 @@ def check_target(target, ir): tvm.testing.assert_allclose(c.asnumpy(), ref) check_target("llvm", searchsorted_ir_cpu) - # check_target("cuda", searchsorted_ir_gpu) - # check_target("nvptx", searchsorted_ir_gpu) + check_target("cuda", searchsorted_ir_gpu) + check_target("nvptx", searchsorted_ir_gpu) if __name__ == "__main__": - # test_prefetch() - # test_if() - # test_for() - # test_cpu() - # test_gpu() + test_prefetch() + test_if() + test_for() + test_cpu() + test_gpu() test_binary_search() From 5592b255567818ea13c4da18a683576198427ee5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 7 Feb 2021 09:31:26 +0900 Subject: [PATCH 06/29] nms updated to use while loop --- python/tvm/topi/cuda/nms.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 152b1bd15987..c62009de5354 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -555,16 +555,20 @@ def nms_inner_loop(ib, j): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms - with ib.for_range(0, nkeep) as j: - # Proceed to the inner loop if the box j is still valid - with ib.if_scope(out_scores[i, j] > -1.0): - with ib.if_scope(max_output_size > 0): - # No need to do more iteration if we have already reached max_output_size - # boxes - # TODO(masahi): Add TIR while loop to realize early exit from the outer loop - with ib.if_scope(num_valid_boxes_local[0] < max_output_size): - nms_inner_loop(ib, j) - with ib.else_scope(): + with ib.if_scope(max_output_size > 0): + # No need to do more iteration if we have already reached max_output_size boxes + box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + box_idx[0] = 0 + with ib.while_loop(num_valid_boxes_local[0] < max_output_size): + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): + nms_inner_loop(ib, box_idx[0]) + box_idx[0] += 1 + + with ib.else_scope(): + with ib.for_range(0, nkeep) as j: + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(out_scores[i, j] > -1.0): nms_inner_loop(ib, j) with ib.if_scope(tx + 0 == 0): From 7896302cb6624c39a078b130d08398400bff50b8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 7 Feb 2021 13:29:33 +0900 Subject: [PATCH 07/29] add missing upper bound check too --- python/tvm/topi/cuda/nms.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index c62009de5354..83b538554ed4 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -521,7 +521,7 @@ def nms_inner_loop(ib, j): offset_j = j * 4 num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) - with ib.for_range(0, num_iter_per_thread) as _k: + with ib.for_range(0, num_iter_per_thread, name="_k") as _k: k = j + 1 + _k * nthread_tx + tx offset_k = k * 4 @@ -559,14 +559,16 @@ def nms_inner_loop(ib, j): # No need to do more iteration if we have already reached max_output_size boxes box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") box_idx[0] = 0 - with ib.while_loop(num_valid_boxes_local[0] < max_output_size): - # Proceed to the inner loop if the box j is still valid + with ib.while_loop( + tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) + ): + # Proceed to the inner loop if the box with id box_idx is still valid with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): nms_inner_loop(ib, box_idx[0]) box_idx[0] += 1 with ib.else_scope(): - with ib.for_range(0, nkeep) as j: + with ib.for_range(0, nkeep, name="j") as j: # Proceed to the inner loop if the box j is still valid with ib.if_scope(out_scores[i, j] > -1.0): nms_inner_loop(ib, j) From f041f85ae53522f10fe687c0ee3b5cd159f86f57 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 7 Feb 2021 16:22:41 +0900 Subject: [PATCH 08/29] add mandelbrot test --- tests/python/unittest/test_tir_ir_builder.py | 103 +++++++++++++++++-- 1 file changed, 96 insertions(+), 7 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 60e58923724c..fcd47b8e8899 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -173,6 +173,94 @@ def check_target(target): check_target("cuda") +def test_while(): + n = 160 + shape = (n * 2, n) + t = 300 + + def mandel_ref(): + def complex_sqr(z): + return np.array([z[0] ** 2 - z[1] ** 2, z[1] * z[0] * 2]) + + pixels = np.zeros(shape) + + for i in range(pixels.shape[0]): + for j in range(pixels.shape[1]): + c = np.array([-0.8, np.cos(t) * 0.2]) + z = np.array([i / n - 1, j / n - 0.5]) * 2 + iterations = 0 + + while np.linalg.norm(z) < 20 and iterations < 50: + z = complex_sqr(z) + c + iterations += 1 + + pixels[i, j] = 1 - iterations * 0.02 + + return pixels + + def mandel(ib, i, j, pixels): + c = ib.allocate("float32", (2,), name="c", scope="local") + z = ib.allocate("float32", (2,), name="z", scope="local") + tmp = ib.allocate("float32", (1,), name="1", scope="local") + iterations = ib.allocate("int32", (1,), name="iterations", scope="local") + + c[0] = -0.8 + c[1] = float(np.cos(t)) * 0.2 + z[0] = (i / float(n) - 1) * 2 + z[1] = (j / float(n) - 0.5) * 2 + iterations[0] = 0 + + def norm(z): + return tvm.tir.sqrt(z[0] * z[0] + z[1] * z[1]) + + with ib.while_loop(tvm.tir.all(norm(z) < 20, iterations[0] < 50)): + tmp[0] = z[0] + z[0] = z[0] * z[0] - z[1] * z[1] + c[0] + z[1] = z[1] * tmp[0] * 2 + c[1] + iterations[0] += 1 + + pixels[i, j] = 1 - iterations[0] * 0.02 + + def mandel_ir_cpu(C): + ib = tvm.tir.ir_builder.create() + ny = C.shape[0] + nx = C.shape[1] + C = ib.buffer_ptr(C) + + with ib.for_range(0, ny, name="i", kind="parallel") as i: + with ib.for_range(0, nx, name="j") as j: + mandel(ib, i, j, C) + + body = ib.get() + + return body + + ref = mandel_ref() + + def check_target(target, ir): + if not tvm.testing.device_enabled(target): + return + + C = te.extern( + shape, + [], + lambda ins, outs: ir(outs[0]), + name="mandel_ir", + dtype="float32", + ) + s = te.create_schedule(C.op) + + with tvm.transform.PassContext(opt_level=3): + func = tvm.build(s, [C], target) + + ctx = tvm.context(target, 0) + c = tvm.nd.array(np.zeros(shape, dtype=C.dtype), ctx) + func(c) + tvm.testing.assert_allclose(c.asnumpy(), ref) + + check_target("llvm", mandel_ir_cpu) + + def test_binary_search(): def binary_search(ib, n, i, Aptr, Bptr, Cptr): lo = ib.allocate("int32", (1,), name="lo", scope="local") @@ -182,7 +270,7 @@ def binary_search(ib, n, i, Aptr, Bptr, Cptr): hi[0] = n v = Bptr[i] - with ib.while_loop(lo[0] < hi[0]) as _: + with ib.while_loop(lo[0] < hi[0]): mid = lo[0] + tvm.tir.floordiv(hi[0] - lo[0], 2).astype("int32") with ib.if_scope(Aptr[mid] < v): lo[0] = mid + 1 @@ -242,7 +330,7 @@ def check_target(target, ir): ) s = te.create_schedule(C.op) - with tvm.transform.PassContext(opt_level=3, disabled_pass=["HoistIfThenElse"]): + with tvm.transform.PassContext(opt_level=3): func = tvm.build(s, [A, B, C], target) ctx = tvm.context(target, 0) @@ -262,9 +350,10 @@ def check_target(target, ir): if __name__ == "__main__": - test_prefetch() - test_if() - test_for() - test_cpu() - test_gpu() + # test_prefetch() + # test_if() + # test_for() + # test_cpu() + # test_gpu() + test_while() test_binary_search() From b53150c02d0255766a7de4dd1e46362aabbc8673 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 7 Feb 2021 16:41:27 +0900 Subject: [PATCH 09/29] add gpu mandel commit ee2363bf8131830cf0fb112890befd6be6a03f36 Author: Masahiro Masuda Date: Fri Jan 29 11:44:02 2021 +0900 enable extern lib offload for nvptx --- tests/python/unittest/test_tir_ir_builder.py | 41 +++++++++++++++++--- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index fcd47b8e8899..fb325ac2ae0e 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -201,7 +201,7 @@ def complex_sqr(z): def mandel(ib, i, j, pixels): c = ib.allocate("float32", (2,), name="c", scope="local") z = ib.allocate("float32", (2,), name="z", scope="local") - tmp = ib.allocate("float32", (1,), name="1", scope="local") + tmp = ib.allocate("float32", (1,), name="tmp", scope="local") iterations = ib.allocate("int32", (1,), name="iterations", scope="local") c[0] = -0.8 @@ -235,6 +235,33 @@ def mandel_ir_cpu(C): return body + def mandel_ir_gpu(C): + ib = tvm.tir.ir_builder.create() + ny = C.shape[0] + nx = C.shape[1] + C = ib.buffer_ptr(C) + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + by = te.thread_axis("blockIdx.y") + ty = te.thread_axis("threadIdx.y") + + max_threads = 16 + ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(nx + max_threads - 1, max_threads)) + ib.scope_attr(tx, "thread_extent", max_threads) + ib.scope_attr(by, "thread_extent", tvm.tir.indexdiv(ny + max_threads - 1, max_threads)) + ib.scope_attr(ty, "thread_extent", max_threads) + + tidx = bx * max_threads + tx + tidy = by * max_threads + ty + + with ib.if_scope(tvm.tir.all(tidx < nx, tidy < ny)): + mandel(ib, tidy, tidx, C) + + body = ib.get() + + return body + ref = mandel_ref() def check_target(target, ir): @@ -259,6 +286,8 @@ def check_target(target, ir): tvm.testing.assert_allclose(c.asnumpy(), ref) check_target("llvm", mandel_ir_cpu) + check_target("npvtx", mandel_ir_gpu) + check_target("cuda", mandel_ir_gpu) def test_binary_search(): @@ -350,10 +379,10 @@ def check_target(target, ir): if __name__ == "__main__": - # test_prefetch() - # test_if() - # test_for() - # test_cpu() - # test_gpu() + test_prefetch() + test_if() + test_for() + test_cpu() + test_gpu() test_while() test_binary_search() From 265cffcdb056254f136514001dae0ec3132dbfd7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 9 Feb 2021 14:34:14 +0900 Subject: [PATCH 10/29] rename test --- tests/python/unittest/test_tir_ir_builder.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index fb325ac2ae0e..108778331623 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -173,7 +173,7 @@ def check_target(target): check_target("cuda") -def test_while(): +def test_while_mandel(): n = 160 shape = (n * 2, n) t = 300 @@ -290,7 +290,11 @@ def check_target(target, ir): check_target("cuda", mandel_ir_gpu) -def test_binary_search(): +def test_collatz(): + pass + + +def test_while_binary_search(): def binary_search(ib, n, i, Aptr, Bptr, Cptr): lo = ib.allocate("int32", (1,), name="lo", scope="local") hi = ib.allocate("int32", (1,), name="hi", scope="local") @@ -384,5 +388,6 @@ def check_target(target, ir): test_for() test_cpu() test_gpu() - test_while() - test_binary_search() + test_while_mandel() + test_while_collatz() + test_while_binary_search() From 0b6a93b18a02bba6d988f116b9ff7f39c95f3d33 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 9 Feb 2021 15:01:09 +0900 Subject: [PATCH 11/29] run black --- python/tvm/tir/ir_builder.py | 1 + python/tvm/tir/stmt.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index a674325e25cd..f5df971c53cb 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -266,6 +266,7 @@ def _exit_cb(): def while_loop(self, condition): """TODO""" self._seq_stack.append([]) + def _exit_cb(): self.emit(_stmt.While(condition, self._pop_seq())) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index f3b734c9e0ef..ec84cf69f27e 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -158,6 +158,7 @@ def __init__( span, ) + @tvm._ffi.register_object("tir.While") class While(Stmt): """TODO""" From 3dfe1894172d4091d03fc6bc0dafbdf6ed73080b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 9 Feb 2021 21:00:55 +0900 Subject: [PATCH 12/29] add doc --- include/tvm/tir/stmt.h | 4 ++-- python/tvm/tir/ir_builder.py | 22 +++++++++++++++++++++- python/tvm/tir/stmt.py | 14 +++++++++++++- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 81712502c9c1..ac660bfb7461 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -873,9 +873,9 @@ class For : public Stmt { */ class WhileNode : public StmtNode { public: - /*! \brief The minimum value of iteration. */ + /*! \brief The termination condition. */ PrimExpr condition; - /*! \brief The body of the for loop. */ + /*! \brief The body of the while loop. */ Stmt body; void VisitAttrs(AttrVisitor* v) { diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index f5df971c53cb..2ecbdeda8371 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -264,7 +264,27 @@ def _exit_cb(): return WithScope(loop_var, _exit_cb) def while_loop(self, condition): - """TODO""" + """Create a while loop scope. + + Parameters + ---------- + condition : Expr + The termination condition. + + Returns + ------- + loop_scope : With.Scope of Var + The while scope. + + Examples + -------- + .. code-block:: python + + ib = tvm.tir.ir_builder.create() + iterations = ib.allocate("int32", (1,), name="iterations", scope="local") + with ib.while_loop(iterations[0] < 10): + iterations[0] += 1 + """ self._seq_stack.append([]) def _exit_cb(): diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index ec84cf69f27e..47462066c364 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -161,7 +161,19 @@ def __init__( @tvm._ffi.register_object("tir.While") class While(Stmt): - """TODO""" + """While node. + + Parameters + ---------- + condition : PrimExpr + The termination condition. + + body : Stmt + The body statement. + + span : Optional[Span] + The location of this itervar in the source code. + """ def __init__(self, condition, body, span=None): self.__init_handle_by_constructor__( From 6834cc39ee9009d22f517f240c5d5c1b8928f934 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 9 Feb 2021 21:24:42 +0900 Subject: [PATCH 13/29] add collatz test --- tests/python/unittest/test_tir_ir_builder.py | 72 ++++++++++++++++++-- 1 file changed, 66 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 108778331623..a2cb26e7dc69 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -173,6 +173,70 @@ def check_target(target): check_target("cuda") +def test_while_collatz(): + """Test while loop + if""" + + def collatz_ref(n): + a = n + i = 0 + while a > 1: + if a % 2 == 1: + a = 3 * a + 1 + else: + a = a >> 1 + i += 1 + return i + + def collatz(ib, n, C): + i = ib.allocate("int32", (1,), name="i", scope="local") + a = ib.allocate("int32", (1,), name="a", scope="local") + i[0] = 0 + a[0] = n + with ib.while_loop(a[0] > 1): + with ib.if_scope(tvm.tir.floormod(a[0], 2) == 1): + a[0] = 3 * a[0] + 1 + with ib.else_scope(): + a[0] = a[0] >> 1 + i[0] += 1 + + C[n] = i[0] + + def collatz_ir_cpu(C): + ib = tvm.tir.ir_builder.create() + n = C.shape[0] + C = ib.buffer_ptr(C) + + with ib.for_range(0, n, name="i", kind="parallel") as i: + collatz(ib, i, C) + + body = ib.get() + + return body + + n = 30 + + def check_target(target, ir): + C = te.extern( + (n,), + [], + lambda ins, outs: ir(outs[0]), + name="collatz", + dtype="int32", + ) + s = te.create_schedule(C.op) + + with tvm.transform.PassContext(opt_level=3): + func = tvm.build(s, [C], target) + + ctx = tvm.context(target, 0) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) + func(c) + ref = np.array([collatz_ref(i) for i in range(n)]) + tvm.testing.assert_allclose(c.asnumpy(), ref) + + check_target("llvm", collatz_ir_cpu) + + def test_while_mandel(): n = 160 shape = (n * 2, n) @@ -283,17 +347,13 @@ def check_target(target, ir): ctx = tvm.context(target, 0) c = tvm.nd.array(np.zeros(shape, dtype=C.dtype), ctx) func(c) - tvm.testing.assert_allclose(c.asnumpy(), ref) + tvm.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-5, atol=1e-5) check_target("llvm", mandel_ir_cpu) check_target("npvtx", mandel_ir_gpu) check_target("cuda", mandel_ir_gpu) -def test_collatz(): - pass - - def test_while_binary_search(): def binary_search(ib, n, i, Aptr, Bptr, Cptr): lo = ib.allocate("int32", (1,), name="lo", scope="local") @@ -388,6 +448,6 @@ def check_target(target, ir): test_for() test_cpu() test_gpu() - test_while_mandel() test_while_collatz() + test_while_mandel() test_while_binary_search() From 92d9add9ffd8908c645e489766046d0064c2e9dc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 9 Feb 2021 21:41:08 +0900 Subject: [PATCH 14/29] add while + vectorize test --- tests/python/unittest/test_tir_ir_builder.py | 56 ++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index a2cb26e7dc69..a75e50446764 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -173,6 +173,61 @@ def check_target(target): check_target("cuda") +def test_while_vectorize(): + """Test while loop + vectorized inner loop""" + + n = 64 + num_iter = 10 + + def test_ir(A, B, C): + ib = tvm.tir.ir_builder.create() + n = C.shape[0] + A = ib.buffer_ptr(A) + B = ib.buffer_ptr(B) + C = ib.buffer_ptr(C) + i = ib.allocate("int32", (1,), name="i", scope="local") + i[0] = 0 + + with ib.for_range(0, n) as j: + C[j] = 0.0 + + with ib.while_loop(i[0] < num_iter): + with ib.for_range(0, n, kind="vectorize") as j: + C[j] += A[j] + B[j] + i[0] += 1 + + return ib.get() + + def check_target(target, ir): + dtype = "float32" + A = te.placeholder((n,), name="A", dtype=dtype) + B = te.placeholder((n,), name="B", dtype=dtype) + + C = te.extern( + (n,), + [A, B], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + name="while_vectorize", + dtype=dtype, + ) + s = te.create_schedule(C.op) + + with tvm.transform.PassContext(opt_level=3): + func = tvm.build(s, [A, B, C], target) + + ctx = tvm.context(target, 0) + a_np = np.random.uniform(size=n).astype(A.dtype) + b_np = np.random.uniform(size=n).astype(B.dtype) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) + func(a, b, c) + ref = num_iter * (a_np + b_np) + tvm.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-5, atol=1e-5) + + check_target("llvm", test_ir) + + def test_while_collatz(): """Test while loop + if""" @@ -448,6 +503,7 @@ def check_target(target, ir): test_for() test_cpu() test_gpu() + test_while_vectorize() test_while_collatz() test_while_mandel() test_while_binary_search() From e56d5707e9d02038a8bec5a86bceba5b2d91e75e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 9 Feb 2021 21:46:18 +0900 Subject: [PATCH 15/29] simplify bin search --- tests/python/unittest/test_tir_ir_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index a75e50446764..428498f820f0 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -419,7 +419,7 @@ def binary_search(ib, n, i, Aptr, Bptr, Cptr): v = Bptr[i] with ib.while_loop(lo[0] < hi[0]): - mid = lo[0] + tvm.tir.floordiv(hi[0] - lo[0], 2).astype("int32") + mid = lo[0] + (hi[0] - lo[0] >> 1) with ib.if_scope(Aptr[mid] < v): lo[0] = mid + 1 with ib.else_scope(): From ef6427858355b1970e89057012d8321d566db06e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Feb 2021 21:17:54 +0900 Subject: [PATCH 16/29] Add special case visit method to storage_access.cc --- src/tir/transforms/storage_access.cc | 13 +++++++++++++ src/tir/transforms/storage_access.h | 1 + 2 files changed, 14 insertions(+) diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index be20724ae207..38143c14b021 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -180,6 +180,19 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { --condition_counter_; } +void StorageAccessVisitor::VisitStmt_(const WhileNode* op) { + ++condition_counter_; + this->VisitExpr(op->condition); + scope_.push_back(std::vector()); + this->VisitStmt(op->body); + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + scope_.back().emplace_back(std::move(s)); + --condition_counter_; +} + void StorageAccessVisitor::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 80bbff4c1fe4..663c570fd15c 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -84,6 +84,7 @@ class StorageAccessVisitor : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const ForNode* op) final; void VisitStmt_(const IfThenElseNode* op) final; + void VisitStmt_(const WhileNode* op) final; void VisitExpr_(const CallNode* op) final; protected: From ff86f1652deab7a54703c665454b097b9d28b841 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Feb 2021 21:31:25 +0900 Subject: [PATCH 17/29] disallow while loop inside vectorized loop --- src/tir/transforms/vectorize_loop.cc | 7 ++- tests/python/unittest/test_tir_ir_builder.py | 51 ++++++++++++++++++-- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 66f4ae329f69..64956bc8ee54 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -388,6 +388,11 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); @@ -441,7 +446,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor Date: Wed, 10 Feb 2021 21:40:46 +0900 Subject: [PATCH 18/29] disallow trivial condition since we do not have break --- src/tir/ir/stmt.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index ae58202748cd..2aeaae3eb592 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -201,6 +201,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) While::While(PrimExpr condition, Stmt body, Span span) { ICHECK(condition.defined()); ICHECK(condition.dtype().is_scalar()); + ICHECK(condition.as() == nullptr) << "The condition should not be trivial."; ICHECK(body.defined()); ObjectPtr node = make_object(); From 220c7eba2e70f0eb8f8e472b9a477ed861a74ed5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Feb 2021 21:48:32 +0900 Subject: [PATCH 19/29] error out in CoprocSync for now --- src/tir/transforms/coproc_sync.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index f9245442d268..424a1bbb0ae6 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -429,6 +429,11 @@ class CoProcInstDepDetector : public StmtVisitor { } } + void VisitStmt_(const WhileNode* op) final { + // TODO(masahi): Do we need a special handling for While nodes? + LOG(FATAL) << "WhileNode not supported in CoProcSync."; + } + // insert before is stored in reverse order // the first element is closest to the node. std::unordered_map > insert_before_; From 3817b5ae277f6d0d0a3b9a3db106b16a231ac406 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Feb 2021 21:52:37 +0900 Subject: [PATCH 20/29] error out LiftAttrScope for now --- src/tir/transforms/lift_attr_scope.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc index 27dd583b8b42..40d152b3b3b6 100644 --- a/src/tir/transforms/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -157,6 +157,12 @@ class AttrScopeLifter : public StmtMutator { } } + Stmt VisitStmt_(const WhileNode* op) final { + // TODO(masahi): Do we need a special handling for While nodes? + LOG(FATAL) << "WhileNode not supported in LiftAttrScope."; + return Stmt(); + } + private: // value comparison that also compares content of int constant static bool ValueSame(const PrimExpr& a, const PrimExpr& b) { From 384ac45dfb03aa2c7eefa79002ed17df92a2de7d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Feb 2021 22:04:56 +0900 Subject: [PATCH 21/29] add placeholder to inject_vpthread --- src/tir/transforms/inject_virtual_thread.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index b24a0e95cd53..09ae21afcca4 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -333,6 +333,12 @@ class VTInjector : public StmtExprMutator { } } + // While + Stmt VisitStmt_(const WhileNode* op) final { + // TODO(masahi): Do we need a special handling for While nodes? + return StmtMutator::VisitStmt_(op); + } + // Seq Stmt VisitStmt_(const SeqStmtNode* op) final { ICHECK_EQ(max_loop_depth_, 0); From da3ca49218075af8e4f144af84adf7d4767804c1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Feb 2021 11:27:05 +0900 Subject: [PATCH 22/29] refactor to use MakeAttach --- src/tir/transforms/storage_rewrite.cc | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 1281d09f2ec8..313f93e449ca 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -352,16 +352,7 @@ class StoragePlanRewriter : public StmtExprMutator { // start rewrite stmt = operator()(std::move(stmt)); if (attach_map_.count(nullptr)) { - std::vector nest; - for (StorageEntry* e : attach_map_.at(nullptr)) { - // ICHECK_EQ(e->scope.rank, 0); - if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, - StringImm(e->scope.to_string()), Evaluate(0))); - nest.push_back(e->new_alloc); - } - } - stmt = MergeNest(nest, stmt); + return MakeAttach(attach_map_.at(nullptr), stmt); } return stmt; } From 626c7ff7c1a10d29d705cf6e0467cb7ff9e82fb8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Feb 2021 11:28:50 +0900 Subject: [PATCH 23/29] handle WhileNode in InplaceOpVerifier --- src/tir/transforms/storage_rewrite.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 313f93e449ca..124e4e2121be 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -246,6 +246,8 @@ class InplaceOpVerifier : public StmtExprVisitor { VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { VisitStmt_(static_cast(stmt)); + } else if (stmt->IsInstance()) { + VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { VisitStmt_(static_cast(stmt)); } else { From 0fadb47fc98290d9c2de462716942f81125ec5e6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Feb 2021 11:30:02 +0900 Subject: [PATCH 24/29] error out in InjectVirtualThread --- src/tir/transforms/inject_virtual_thread.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 09ae21afcca4..4ef10f326bb0 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -335,8 +335,9 @@ class VTInjector : public StmtExprMutator { // While Stmt VisitStmt_(const WhileNode* op) final { - // TODO(masahi): Do we need a special handling for While nodes? - return StmtMutator::VisitStmt_(op); + // TODO(masahi): What should we do for While nodes? + LOG(FATAL) << "WhileNode in InjectVirtualThread not supported yet"; + return Stmt(); } // Seq From 45818ea3c97ffb9edcfd91bb7352d42a1d009532 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Feb 2021 11:32:14 +0900 Subject: [PATCH 25/29] try handle WhileNode in StoragePlanRewriter --- src/tir/transforms/storage_rewrite.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 124e4e2121be..90606acaf3eb 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -446,6 +446,18 @@ class StoragePlanRewriter : public StmtExprMutator { } } + Stmt VisitStmt_(const WhileNode* op) final { + // remake all the allocation at the attach scope. + if (attach_map_.count(op)) { + auto& svec = attach_map_[op]; + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + return While(op->condition, MakeAttach(svec, op->body)); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + Stmt VisitStmt_(const AllocateNode* op) final { return this->VisitStmt(op->body); } private: From f442ecc2550fc41b185e67364770f4b18d8cc166 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 2 Mar 2021 15:41:47 +0900 Subject: [PATCH 26/29] remove WhileNode visitor from storage rewrite --- src/tir/transforms/storage_rewrite.cc | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 90606acaf3eb..36eeddb17d89 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -432,6 +432,7 @@ class StoragePlanRewriter : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } } + Stmt VisitStmt_(const ForNode* op) final { ICHECK(op->kind != ForKind::kVectorized) << "VectorizeLoop before LiftStorageAlloc"; // remake all the allocation at the attach scope. @@ -446,18 +447,6 @@ class StoragePlanRewriter : public StmtExprMutator { } } - Stmt VisitStmt_(const WhileNode* op) final { - // remake all the allocation at the attach scope. - if (attach_map_.count(op)) { - auto& svec = attach_map_[op]; - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - return While(op->condition, MakeAttach(svec, op->body)); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } - Stmt VisitStmt_(const AllocateNode* op) final { return this->VisitStmt(op->body); } private: From 3012876d8a036a2370711e24fc531002382171c4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 2 Mar 2021 15:48:04 +0900 Subject: [PATCH 27/29] add while loop storage rewrite test --- tests/python/unittest/test_tir_ir_builder.py | 1 + .../test_tir_transform_storage_rewrite.py | 40 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 571e476c6649..a10bb2a86802 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -535,6 +535,7 @@ def test_ir(A, B, C): try: tvm.lower(s, [A, B, C], "llvm") + assert False except tvm.error.TVMError as e: error_msg = str(e).split("\n")[-1] expected = "A while loop inside a vectorized loop not supported" diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 49adcfb568a7..50bdc1eeec65 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -297,6 +297,46 @@ def test_parallel_alloc(): assert isinstance(body.body.body.body.body, tvm.tir.Allocate) + ib = tvm.tir.ir_builder.create() + n = te.var("n") + with ib.for_range(0, n, name="i", kind="parallel") as i: + j = ib.allocate("int32", 1, name="j", scope="global") + j[0] = 0 + with ib.while_loop(j[0] < 10): + A = ib.allocate("float32", n, name="A", scope="global") + A[j[0]] = A[j[0]] + 2 + j[0] += j[0] + 1 + + body = ib.get() + # parallel (i, 0, n) { + # // attr [j] storage_scope = "global" + # allocate j[int32 * 1] + # j[0] = 0 + # while((j[0] < 10)){ + # // attr [A] storage_scope = "global" + # allocate A[float32 * n] + # A[j[0]] = (A[j[0]] + 2f) + # j[0] = (j[0] + (j[0] + 1)) + # } + # } + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + + # parallel (i, 0, n) { + # // attr [j] storage_scope = "global" + # allocate j[int32 * 1] + # // attr [A] storage_scope = "global" + # allocate A[float32 * n] + # j[0] = 0 + # while((j[0] < 10)){ + # A[j[0]] = (A[j[0]] + 2f) + # j[0] = (j[0] + (j[0] + 1)) + # } + # } + assert isinstance(body.body.body, tvm.tir.Allocate) # j + assert isinstance(body.body.body.body.body, tvm.tir.Allocate) # A + def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): # Test Buffer From c3af5ae9aa611580004ce03d16aa952ab124d826 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 2 Mar 2021 16:47:58 +0900 Subject: [PATCH 28/29] update tests --- .../test_tir_transform_storage_rewrite.py | 57 ++++++++++++++----- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 50bdc1eeec65..dbe7e04700d9 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -297,17 +297,23 @@ def test_parallel_alloc(): assert isinstance(body.body.body.body.body, tvm.tir.Allocate) - ib = tvm.tir.ir_builder.create() - n = te.var("n") - with ib.for_range(0, n, name="i", kind="parallel") as i: - j = ib.allocate("int32", 1, name="j", scope="global") - j[0] = 0 - with ib.while_loop(j[0] < 10): - A = ib.allocate("float32", n, name="A", scope="global") - A[j[0]] = A[j[0]] + 2 - j[0] += j[0] + 1 - body = ib.get() +def test_while_alloc(): + def get_mod(kind="serial"): + ib = tvm.tir.ir_builder.create() + n = te.var("n") + with ib.for_range(0, n, name="i", kind=kind) as i: + j = ib.allocate("int32", 1, name="j", scope="global") + j[0] = 0 + with ib.while_loop(j[0] < 10): + A = ib.allocate("float32", n, name="A", scope="global") + A[j[0]] = A[j[0]] + 2 + j[0] += j[0] + 1 + + body = ib.get() + return tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + + mod = get_mod(kind="parallel") # parallel (i, 0, n) { # // attr [j] storage_scope = "global" # allocate j[int32 * 1] @@ -319,10 +325,7 @@ def test_parallel_alloc(): # j[0] = (j[0] + (j[0] + 1)) # } # } - - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) body = tvm.tir.transform.StorageRewrite()(mod)["main"].body - # parallel (i, 0, n) { # // attr [j] storage_scope = "global" # allocate j[int32 * 1] @@ -337,6 +340,33 @@ def test_parallel_alloc(): assert isinstance(body.body.body, tvm.tir.Allocate) # j assert isinstance(body.body.body.body.body, tvm.tir.Allocate) # A + mod = get_mod(kind="serial") + # for (i, 0, n) { + # // attr [j] storage_scope = "global" + # allocate j[int32 * 1] + # j[0] = 0 + # while((j[0] < 10)){ + # // attr [A] storage_scope = "global" + # allocate A[float32 * n] + # A[j[0]] = (A[j[0]] + 2f) + # j[0] = (j[0] + (j[0] + 1)) + # } + # } + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + # // attr [j] storage_scope = "global" + # allocate j[int32 * 1] + # // attr [A] storage_scope = "global" + # allocate A[float32 * n] + # for (i, 0, n) { + # j[0] = 0 + # while((j[0] < 10)){ + # A[j[0]] = (A[j[0]] + 2f) + # j[0] = (j[0] + (j[0] + 1)) + # } + # } + assert isinstance(body.body, tvm.tir.Allocate) # j + assert isinstance(body.body.body.body, tvm.tir.Allocate) # A + def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): # Test Buffer @@ -616,6 +646,7 @@ def verify(n): test_alloc_different_dtypes() test_inplace_rule() test_parallel_alloc() + test_while_alloc() test_storage_combine() test_storage_share_gpu() test_inplace_rule2() From 35b8e2825d9caa323fda46f75afd5619cd76e16b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 3 Mar 2021 05:09:25 +0900 Subject: [PATCH 29/29] move test_vectorize_while_fail to test_tir_transform_vectorize.py --- tests/python/unittest/test_tir_ir_builder.py | 48 ------------------- .../unittest/test_tir_transform_vectorize.py | 48 +++++++++++++++++++ 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index a10bb2a86802..46bc500fc503 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -495,53 +495,6 @@ def check_target(target, ir): check_target("nvptx", searchsorted_ir_gpu) -def test_vectorize_while_fail(): - """A while loop inside a vectorized loop should fail.""" - - n = 64 - num_iter = 10 - - def test_ir(A, B, C): - ib = tvm.tir.ir_builder.create() - n = C.shape[0] - A = ib.buffer_ptr(A) - B = ib.buffer_ptr(B) - C = ib.buffer_ptr(C) - i = ib.allocate("int32", (1,), name="i", scope="local") - i[0] = 0 - - with ib.for_range(0, n) as j: - C[j] = 0.0 - - with ib.for_range(0, n, kind="vectorize") as j: - with ib.while_loop(i[0] < num_iter): - C[j] += A[j] + B[j] - i[0] += 1 - - return ib.get() - - dtype = "float32" - A = te.placeholder((n,), name="A", dtype=dtype) - B = te.placeholder((n,), name="B", dtype=dtype) - - C = te.extern( - (n,), - [A, B], - lambda ins, outs: test_ir(ins[0], ins[1], outs[0]), - name="while_vectorize", - dtype=dtype, - ) - s = te.create_schedule(C.op) - - try: - tvm.lower(s, [A, B, C], "llvm") - assert False - except tvm.error.TVMError as e: - error_msg = str(e).split("\n")[-1] - expected = "A while loop inside a vectorized loop not supported" - assert expected in error_msg - - if __name__ == "__main__": test_prefetch() test_if() @@ -552,4 +505,3 @@ def test_ir(A, B, C): test_while_collatz() test_while_mandel() test_while_binary_search() - test_vectorize_while_fail() diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index 5ae47e01f681..b1e580957b24 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -158,6 +158,53 @@ def test_vectorize_if_then_else(): assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast) +def test_vectorize_while_fail(): + """A while loop inside a vectorized loop should fail.""" + + n = 64 + num_iter = 10 + + def test_ir(A, B, C): + ib = tvm.tir.ir_builder.create() + n = C.shape[0] + A = ib.buffer_ptr(A) + B = ib.buffer_ptr(B) + C = ib.buffer_ptr(C) + i = ib.allocate("int32", (1,), name="i", scope="local") + i[0] = 0 + + with ib.for_range(0, n) as j: + C[j] = 0.0 + + with ib.for_range(0, n, kind="vectorize") as j: + with ib.while_loop(i[0] < num_iter): + C[j] += A[j] + B[j] + i[0] += 1 + + return ib.get() + + dtype = "float32" + A = te.placeholder((n,), name="A", dtype=dtype) + B = te.placeholder((n,), name="B", dtype=dtype) + + C = te.extern( + (n,), + [A, B], + lambda ins, outs: test_ir(ins[0], ins[1], outs[0]), + name="while_vectorize", + dtype=dtype, + ) + s = te.create_schedule(C.op) + + try: + tvm.lower(s, [A, B, C], "llvm") + assert False + except tvm.error.TVMError as e: + error_msg = str(e).split("\n")[-1] + expected = "A while loop inside a vectorized loop not supported" + assert expected in error_msg + + if __name__ == "__main__": test_vectorize_vector() test_vectorize_with_if() @@ -166,3 +213,4 @@ def test_vectorize_if_then_else(): test_vectorize_with_le_cond() test_vectorize_with_ge_cond() test_vectorize_let() + test_vectorize_while_fail()