From af210b0029855688580ce23ee575ef1fdb3d9b4b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 30 Jan 2021 10:15:47 +0900 Subject: [PATCH 01/16] usage experiment of for loop with test --- python/tvm/tir/ir_builder.py | 2 +- python/tvm/topi/cuda/nms.py | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 437e8f6610f4..c8317500876e 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -206,7 +206,7 @@ def scope_attr(self, node, attr_key, value): value = op.max(1, value) self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x)) - def for_range(self, begin, end, name="i", dtype="int32", kind="serial"): + def for_range(self, begin, end, test=None, name="i", dtype="int32", kind="serial"): """Create a for iteration scope. Parameters diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 2d6e1e464ef8..7578482bd6c7 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -541,16 +541,17 @@ def nms_inner_loop(ib, j): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms - with ib.for_range(0, nkeep) as j: - # Proceed to the inner loop if the box j is still valid - with ib.if_scope(out_scores[i, j] > -1.0): - with ib.if_scope(max_output_size > 0): - # No need to do more iteration if we have already reached max_output_size - # boxes - # TODO(masahi): Add TIR while loop to realize early exit from the outer loop - with ib.if_scope(num_valid_boxes_local[0] < max_output_size): - nms_inner_loop(ib, j) - with ib.else_scope(): + with ib.if_scope(max_output_size > 0): + # No need to do more iteration if we have already reached max_output_size + # boxes + # TODO(masahi): Add TIR while loop to realize early exit from the outer loop + with ib.for_range(0, nkeep, test=num_valid_boxes_local[0] < max_output_size) as j: + nms_inner_loop(ib, j) + + with ib.else_scope(): + with ib.for_range(0, nkeep) as j: + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(out_scores[i, j] > -1.0): nms_inner_loop(ib, j) with ib.if_scope(tx + 0 == 0): From 3119de025566564b58329db4843cb41e1a10ac74 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 11:58:49 +0900 Subject: [PATCH 02/16] add test to for node and make it compile --- include/tvm/tir/stmt.h | 10 ++++++++-- python/tvm/tir/ir_builder.py | 5 ++++- python/tvm/tir/stmt.py | 2 ++ src/target/llvm/codegen_cpu.cc | 2 +- src/te/operation/hybrid_op.cc | 6 +++--- .../schedule_postproc_rewrite_for_tensor_core.cc | 2 +- src/tir/ir/stmt.cc | 8 +++++--- src/tir/transforms/ir_utils.cc | 2 +- src/tir/transforms/narrow_datatype.cc | 2 +- src/tir/transforms/storage_rewrite.cc | 2 +- src/tir/transforms/unroll_loop.cc | 2 +- src/tir/transforms/vectorize_loop.cc | 2 +- vta/python/vta/transform.py | 10 +++++++++- 13 files changed, 38 insertions(+), 17 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 093d49ca2dd4..4ddc9c0adf34 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -802,6 +802,9 @@ class ForNode : public StmtNode { ForKind kind; /*! \brief The body of the for loop. */ Stmt body; + /*! \brief The test condition of the for loop. */ + Optional test; + /*! * \brief Only valid when kind == ForKind::kThreadBinding * The context thread that this loop variable bounds to. @@ -823,6 +826,7 @@ class ForNode : public StmtNode { v->Visit("extent", &extent); v->Visit("kind", &kind); v->Visit("body", &body); + v->Visit("test", &test); v->Visit("thread_binding", &thread_binding); v->Visit("annotations", &annotations); v->Visit("span", &span); @@ -831,7 +835,8 @@ class ForNode : public StmtNode { bool SEqualReduce(const ForNode* other, SEqualReducer equal) const { return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) && equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) && - equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations); + equal(test, other->test) && equal(thread_binding, other->thread_binding) && + equal(annotations, other->annotations); } void SHashReduce(SHashReducer hash_reduce) const { @@ -840,6 +845,7 @@ class ForNode : public StmtNode { hash_reduce(extent); hash_reduce(kind); hash_reduce(body); + hash_reduce(test); hash_reduce(thread_binding); hash_reduce(annotations); } @@ -854,7 +860,7 @@ class ForNode : public StmtNode { */ class For : public Stmt { public: - TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, + TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional test = NullOpt, Optional thread_binding = NullOpt, Map annotations = Map(), Span span = Span()); diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index c8317500876e..bbe63ecf1c8b 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -248,6 +248,9 @@ def for_range(self, begin, end, test=None, name="i", dtype="int32", kind="serial loop_var = _expr.Var(name, dtype=dtype) extent = end if begin == 0 else (end - begin) + if test is not None: + assert kind == "serial" + def _exit_cb(): if kind == "serial": kind_id = _stmt.ForKind.SERIAL @@ -259,7 +262,7 @@ def _exit_cb(): kind_id = _stmt.ForKind.UNROLLED else: raise ValueError("Unknown kind") - self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq())) + self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq(), test)) return WithScope(loop_var, _exit_cb) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 9e1ef56cca58..ea53911c45e2 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -138,6 +138,7 @@ def __init__( extent, kind, body, + test=None, thread_binding=None, annotations=None, span=None, @@ -149,6 +150,7 @@ def __init__( extent, kind, body, + test, thread_binding, annotations, span, diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index e2a8553199f0..3f8c63ad3f71 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -980,7 +980,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { CodeGenLLVM::VisitStmt_(op); } else if (op->kind == ForKind::kParallel) { if (parallel_env_.penv == nullptr) { - CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body, + CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body, op->test, op->thread_binding, op->annotations), 0); } else { diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 65b8660ca1fb..a0ff74c095df 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -277,7 +277,7 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_maploop_var.get()] = indexdiv(parent, extent); body = tir::Substitute(body, rmap); under_outer = false; - return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body, + return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body, op->test, op->thread_binding, op->annotations); } else if (under_outer) { Stmt body = this->VisitStmt(op->body); @@ -332,7 +332,7 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_mapextent, body); } else { return For(op->loop_var, op->min, op->extent, IterVarTypeToForKind(attr->iter_type), - op->body, op->thread_binding, op->annotations); + op->body, op->test, op->thread_binding, op->annotations); } } return StmtMutator::VisitStmt_(op); @@ -414,7 +414,7 @@ Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map kind = IterVarTypeToForKind(stage->iter_var_attrs[target]->iter_type); } const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second; - return For(target->var, range->min, range->extent, kind, body, op->thread_binding, + return For(target->var, range->min, range->extent, kind, body, op->test, op->thread_binding, op->annotations); } }; diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 74d1a19d2cfe..ab9d248b86be 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -968,7 +968,7 @@ class TensorCoreIRMutator : public StmtExprMutator { scaled_extent_value = ori_extent_value / scale_factor; } PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->thread_binding, + stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->test, op->thread_binding, op->annotations); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 92dc38797544..ac7753d5986d 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -129,7 +129,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, - Optional thread_binding, Map annotations, Span span) { + Optional test, Optional thread_binding, + Map annotations, Span span) { ICHECK(min.defined()); ICHECK(extent.defined()); ICHECK(min.dtype().is_scalar()); @@ -143,6 +144,7 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, node->extent = std::move(extent); node->kind = kind; node->body = std::move(body); + node->test = std::move(test); node->thread_binding = std::move(thread_binding); node->annotations = std::move(annotations); node->span = std::move(span); @@ -150,9 +152,9 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, } TVM_REGISTER_GLOBAL("tir.For").set_body_typed( - [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, + [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, Optional test, Optional thread_binding, Optional> annotations, Span span) { - return For(loop_var, min, extent, static_cast(kind), body, thread_binding, + return For(loop_var, min, extent, static_cast(kind), body, test, thread_binding, annotations.value_or(Map()), span); }); diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index cbae3f95ec68..fb6f288f5f11 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -149,7 +149,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return For(new_var, op->min, op->extent, op->kind, op->body, op->thread_binding, + return For(new_var, op->min, op->extent, op->kind, op->body, op->test, op->thread_binding, op->annotations); } else { defined_.insert(v.get()); diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index dc34626205a1..d0969d6f8aa5 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -221,7 +221,7 @@ class DataTypeRewriter : public StmtExprMutator { PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body, - op->thread_binding, op->annotations); + op->test, op->thread_binding, op->annotations); } Stmt VisitStmt_(const AttrStmtNode* op) final { diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 0b1429ca7efa..5603cb4f2061 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -444,7 +444,7 @@ class StoragePlanRewriter : public StmtExprMutator { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body), + return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body), op->test, op->thread_binding, op->annotations); } else { return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index c6e0b5c5f41e..f6821c71f9ae 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -125,7 +125,7 @@ class LoopUnroller : public StmtExprMutator { } else { if (auto_unroll) { if (op->kind != ForKind::kUnrolled) { - return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body, + return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body, op->test, op->thread_binding, op->annotations); } } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 66f4ae329f69..c8b30e750471 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -365,7 +365,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorextent) && body.same_as(op->body)) { return GetRef(op); } else { - return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding, + return For(op->loop_var, op->min, extent, op->kind, body, op->test, op->thread_binding, op->annotations); } } diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 9770857fb0b9..f22d70b8b3a1 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -236,6 +236,7 @@ def _merge_block(slist, body): op.extent, op.kind, body, + op.test, op.thread_binding, op.annotations, ) @@ -321,7 +322,14 @@ def _do_fold(stmt): op = stmt.body assert isinstance(op, tvm.tir.For) return tvm.tir.For( - op.loop_var, op.min, 2, op.kind, op.body, op.thread_binding, op.annotations + op.loop_var, + op.min, + 2, + op.kind, + op.body, + op.test, + op.thread_binding, + op.annotations, ) return None From 8eb05b7715d583fb467ca20ee97d463c10d7cdb7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 12:05:57 +0900 Subject: [PATCH 03/16] pass test to CreateSerialFor --- src/target/llvm/codegen_cpu.cc | 9 +++++++-- src/target/llvm/codegen_llvm.cc | 9 +++++++-- src/target/llvm/codegen_llvm.h | 2 +- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 3f8c63ad3f71..387baea390a4 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -994,15 +994,20 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { ICHECK(!parallel_env_.in_parallel_loop) << "Nested parallel loop is not supported by threadpool, try fuse them instead"; parallel_env_.in_parallel_loop = true; + llvm::Value* test = nullptr; + if (op->test) { + test = MakeValue(op->test.value()); + } if (parallel_env_.stride_pattern) { CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task), - op->loop_var, op->body); + op->loop_var, op->body, test); } else { PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; PrimExpr begin = min(task_id * step, op->extent); PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent); CreateSerialFor(MakeValue(begin), MakeValue(end), - llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); + llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body, + test); } parallel_env_.in_parallel_loop = false; ++parallel_env_.parallel_loop_count; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 1dd76f6b9d51..752ca9a4f10d 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -661,7 +661,7 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { } void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, - const Var& loop_var, const Stmt& body) { + const Var& loop_var, const Stmt& body, llvm::Value* test) { using llvm::BasicBlock; BasicBlock* pre_block = builder_->GetInsertBlock(); BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_); @@ -1324,8 +1324,13 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { } else { ICHECK(op->kind == ForKind::kSerial); } + llvm::Value* test = nullptr; + if (op->test) { + test = MakeValue(op->test.value()); + } CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), - llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); + llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body, + test); } void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 71583708da2c..9b92878062ee 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -291,7 +291,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes); // Create serial for void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, - const Var& loop_var, const Stmt& body); + const Var& loop_var, const Stmt& body, llvm::Value* test); // add alias information. void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index); // The IRBuilder. From 204d0bae075e9751e203b1c77b35bd40ccf37a2b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 12:20:41 +0900 Subject: [PATCH 04/16] support test in for condition --- src/target/llvm/codegen_llvm.cc | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 752ca9a4f10d..f7a3bff33829 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -673,8 +673,16 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va loop_value->addIncoming(begin, pre_block); ICHECK(!var_map_.count(loop_var.get())); var_map_[loop_var.get()] = loop_value; - builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end, - md_very_likely_branch_); + + llvm::Value* cond = nullptr; + llvm::Value* less_than = CreateLT(loop_var.dtype(), loop_value, end); + if (test) { + cond = builder_->CreateAnd(less_than, test); + } else { + cond = less_than; + } + builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_); + builder_->SetInsertPoint(for_body); this->VisitStmt(body); var_map_.erase(loop_var.get()); From 83438c96804e636663f58b951080cb2fc77918cd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 13:54:49 +0900 Subject: [PATCH 05/16] fix nms mod --- python/tvm/topi/cuda/nms.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 7578482bd6c7..3655edc54fa1 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -545,8 +545,11 @@ def nms_inner_loop(ib, j): # No need to do more iteration if we have already reached max_output_size # boxes # TODO(masahi): Add TIR while loop to realize early exit from the outer loop - with ib.for_range(0, nkeep, test=num_valid_boxes_local[0] < max_output_size) as j: - nms_inner_loop(ib, j) + with ib.for_range(0, nkeep) as j: + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(num_valid_boxes_local[0] < max_output_size): + with ib.if_scope(out_scores[i, j] > -1.0): + nms_inner_loop(ib, j) with ib.else_scope(): with ib.for_range(0, nkeep) as j: From 9776dcaf1673b0b4e05dae612cba7da45757f0ac Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 15:20:50 +0900 Subject: [PATCH 06/16] cpu test working --- src/printer/tir_text_printer.cc | 7 ++++++- src/target/llvm/codegen_cpu.cc | 8 ++------ src/target/llvm/codegen_llvm.cc | 14 ++++---------- src/target/llvm/codegen_llvm.h | 2 +- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 4b0871ae2ce6..d54257f775ee 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -486,7 +486,12 @@ inline const char* ForKind2String(ForKind t) { Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { Doc doc; doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " - << Print(op->min + op->extent) << ")"; + << Print(op->min + op->extent); + if (op->test) { + doc << ", (" << Print(op->test.value()) << "))"; + } else { + doc << ")"; + } if (op->kind != ForKind::kSerial) { doc << " " << Doc::StrLiteral(ForKind2String(op->kind)); } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 387baea390a4..83f90b208bf2 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -994,20 +994,16 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { ICHECK(!parallel_env_.in_parallel_loop) << "Nested parallel loop is not supported by threadpool, try fuse them instead"; parallel_env_.in_parallel_loop = true; - llvm::Value* test = nullptr; - if (op->test) { - test = MakeValue(op->test.value()); - } if (parallel_env_.stride_pattern) { CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task), - op->loop_var, op->body, test); + op->loop_var, op->body, op->test); } else { PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; PrimExpr begin = min(task_id * step, op->extent); PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent); CreateSerialFor(MakeValue(begin), MakeValue(end), llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body, - test); + op->test); } parallel_env_.in_parallel_loop = false; ++parallel_env_.parallel_loop_count; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f7a3bff33829..784650a0e68d 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -661,7 +661,7 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { } void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, - const Var& loop_var, const Stmt& body, llvm::Value* test) { + const Var& loop_var, const Stmt& body, Optional test) { using llvm::BasicBlock; BasicBlock* pre_block = builder_->GetInsertBlock(); BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_); @@ -674,12 +674,10 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va ICHECK(!var_map_.count(loop_var.get())); var_map_[loop_var.get()] = loop_value; - llvm::Value* cond = nullptr; llvm::Value* less_than = CreateLT(loop_var.dtype(), loop_value, end); + llvm::Value* cond = less_than; if (test) { - cond = builder_->CreateAnd(less_than, test); - } else { - cond = less_than; + cond = builder_->CreateAnd(less_than, MakeValue(test.value())); } builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_); @@ -1332,13 +1330,9 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { } else { ICHECK(op->kind == ForKind::kSerial); } - llvm::Value* test = nullptr; - if (op->test) { - test = MakeValue(op->test.value()); - } CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body, - test); + op->test); } void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 9b92878062ee..3a94ae817d6e 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -291,7 +291,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes); // Create serial for void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, - const Var& loop_var, const Stmt& body, llvm::Value* test); + const Var& loop_var, const Stmt& body, Optional test); // add alias information. void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index); // The IRBuilder. From 669ae48b8a9a873206c502e79fcbea53822da764 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 17:30:37 +0900 Subject: [PATCH 07/16] nms with early exit working --- python/tvm/topi/cuda/nms.py | 8 +++----- src/tir/ir/stmt_functor.cc | 11 +++++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 3655edc54fa1..e154f2e50c12 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -544,12 +544,10 @@ def nms_inner_loop(ib, j): with ib.if_scope(max_output_size > 0): # No need to do more iteration if we have already reached max_output_size # boxes - # TODO(masahi): Add TIR while loop to realize early exit from the outer loop - with ib.for_range(0, nkeep) as j: + with ib.for_range(0, nkeep, test=(num_valid_boxes_local[0] < max_output_size)) as j: # Proceed to the inner loop if the box j is still valid - with ib.if_scope(num_valid_boxes_local[0] < max_output_size): - with ib.if_scope(out_scores[i, j] > -1.0): - nms_inner_loop(ib, j) + with ib.if_scope(out_scores[i, j] > -1.0): + nms_inner_loop(ib, j) with ib.else_scope(): with ib.for_range(0, nkeep) as j: diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index e0ccb49fc454..9143dd580864 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -43,6 +43,9 @@ void StmtVisitor::VisitStmt_(const ForNode* op) { this->VisitExpr(op->min); this->VisitExpr(op->extent); this->VisitStmt(op->body); + if (op->test) { + this->VisitExpr(op->test.value()); + } } void StmtVisitor::VisitStmt_(const AllocateNode* op) { @@ -168,6 +171,11 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) { PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); Stmt body = this->VisitStmt(op->body); + Optional test = NullOpt; + if (op->test) { + test = this->VisitExpr(op->test.value()); + } + if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { @@ -175,6 +183,9 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) { n->min = std::move(min); n->extent = std::move(extent); n->body = std::move(body); + if (test) { + n->test = std::move(test); + } return Stmt(n); } } From c36f84e65fc638e07e5ae8f25e10cd3460a0cc62 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 17:58:23 +0900 Subject: [PATCH 08/16] support c source commit ee2363bf8131830cf0fb112890befd6be6a03f36 Author: Masahiro Masuda Date: Fri Jan 29 11:44:02 2021 +0900 enable extern lib offload for nvptx --- src/target/source/codegen_c.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index af175c7f2208..b95988a8d6e4 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -891,7 +891,12 @@ void CodeGenC::VisitStmt_(const ForNode* op) { ICHECK(is_zero(op->min)); stream << "for ("; PrintType(op->loop_var.dtype(), stream); - stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n"; + stream << ' ' << vid << " = 0; " << vid << " < " << extent; + if (op->test) { + std::string test = PrintExpr(op->test.value()); + stream << " && (" << test << ")"; + } + stream << "; ++" << vid << ") {\n"; int for_scope = BeginScope(); PrintStmt(op->body); this->EndScope(for_scope); From 1e9abdb4b9079a6c83f8b40930ac6cef6831108c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 20:44:26 +0900 Subject: [PATCH 09/16] binary search test on cpu working --- src/tir/transforms/hoist_if_then_else.cc | 4 ++ tests/python/unittest/test_tir_ir_builder.py | 67 ++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index 7bae0ce8ca75..1212a3afc267 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -142,6 +142,10 @@ class HoistCandidateSelector final : public StmtExprVisitor { HoistCandidateSelector() { InitRecorder(); } void VisitStmt_(const ForNode* op) final { + if (op->test) { + // Do not hoist if this is a while loop + return; + } // If already recording complete, // then stop tracing if (RecordingComplete()) { diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index b84ee09b9fd9..54a14460264d 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -173,9 +173,76 @@ def check_target(target): check_target("cuda") +def test_while_cpu(): + n = 1024 + dtype = "float32" + A = te.placeholder((n,), name="A", dtype=dtype) + B = te.placeholder((n,), name="B", dtype=dtype) + + def searchsorted_ir(A, B, C): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + with ib.for_range(0, n, name="i", kind="parallel") as i: + lo = ib.allocate("int32", (1,), name="lo", scope="local") + hi = ib.allocate("int32", (1,), name="hi", scope="local") + + lo[0] = 0 + hi[0] = n - 1 + v = Bptr[i] + num_loop = int(np.log2(n)) + 1 + + with ib.for_range(0, num_loop, test=(lo[0] <= hi[0])) as _: + mid = tvm.tir.floordiv(lo[0] + hi[0], 2).astype("int32") + with ib.if_scope(Aptr[mid] < v): + lo[0] = mid + 1 + with ib.else_scope(): + with ib.if_scope(Aptr[mid] > v): + hi[0] = mid - 1 + + Cptr[i] = lo[0] + + body = ib.get() + + return body + + C = te.extern( + A.shape, + [A, B], + lambda ins, outs: searchsorted_ir(ins[0], ins[1], outs[0]), + name="searchsorted_ir", + dtype="int32", + ) + s = te.create_schedule(C.op) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + # build and invoke the kernel. + with tvm.transform.PassContext(opt_level=3, disabled_pass=["HoistIfThenElse"]): + func = tvm.build(s, [A, B, C], target) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + ctx = tvm.context(target, 0) + # launch the kernel. + a_np = np.random.uniform(size=n).astype(A.dtype) + b_np = np.random.uniform(size=n).astype(B.dtype) + a_np = np.sort(a_np) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) + func(a, b, c) + ref = np.searchsorted(a_np, b_np) + tvm.testing.assert_allclose(c.asnumpy(), ref) + + check_target("llvm") + + if __name__ == "__main__": test_prefetch() test_if() test_for() test_cpu() test_gpu() + test_while_cpu() From 721cc1821a0af8205278ad8316a49512be973f12 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 20:54:50 +0900 Subject: [PATCH 10/16] gpu version working --- tests/python/unittest/test_tir_ir_builder.py | 87 ++++++++++++++++++-- 1 file changed, 81 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 54a14460264d..d49bf9e42bfd 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -239,10 +239,85 @@ def check_target(target): check_target("llvm") +def test_while_gpu(): + n = 1024 + dtype = "float32" + A = te.placeholder((n,), name="A", dtype=dtype) + B = te.placeholder((n,), name="B", dtype=dtype) + idxd = tvm.tir.indexdiv + + def searchsorted_ir(A, B, C): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + max_threads = 32 + ib.scope_attr(bx, "thread_extent", idxd(n + max_threads - 1, max_threads)) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.if_scope(tid < n): + lo = ib.allocate("int32", (1,), name="lo", scope="local") + hi = ib.allocate("int32", (1,), name="hi", scope="local") + + lo[0] = 0 + hi[0] = n - 1 + v = Bptr[tid] + num_loop = int(np.log2(n)) + 1 + + with ib.for_range(0, num_loop, test=(lo[0] <= hi[0])) as _: + mid = tvm.tir.floordiv(lo[0] + hi[0], 2).astype("int32") + with ib.if_scope(Aptr[mid] < v): + lo[0] = mid + 1 + with ib.else_scope(): + with ib.if_scope(Aptr[mid] > v): + hi[0] = mid - 1 + + Cptr[tid] = lo[0] + + body = ib.get() + + return body + + C = te.extern( + A.shape, + [A, B], + lambda ins, outs: searchsorted_ir(ins[0], ins[1], outs[0]), + name="searchsorted_ir", + dtype="int32", + ) + s = te.create_schedule(C.op) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + # build and invoke the kernel. + with tvm.transform.PassContext(opt_level=3, disabled_pass=["HoistIfThenElse"]): + func = tvm.build(s, [A, B, C], target) + ctx = tvm.context(target, 0) + # launch the kernel. + a_np = np.random.uniform(size=n).astype(A.dtype) + b_np = np.random.uniform(size=n).astype(B.dtype) + a_np = np.sort(a_np) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) + func(a, b, c) + ref = np.searchsorted(a_np, b_np) + tvm.testing.assert_allclose(c.asnumpy(), ref) + + check_target("cuda") + check_target("nvptx") + + if __name__ == "__main__": - test_prefetch() - test_if() - test_for() - test_cpu() - test_gpu() - test_while_cpu() + # test_prefetch() + # test_if() + # test_for() + # test_cpu() + # test_gpu() + # test_while_cpu() + test_while_gpu() From 23b1ee3d68ff4fca26a4c496bdbfc22e8527db08 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 21:09:18 +0900 Subject: [PATCH 11/16] clean up test --- tests/python/unittest/test_tir_ir_builder.py | 151 ++++++------------- 1 file changed, 50 insertions(+), 101 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index d49bf9e42bfd..be2334bab5d7 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -173,80 +173,40 @@ def check_target(target): check_target("cuda") -def test_while_cpu(): - n = 1024 - dtype = "float32" - A = te.placeholder((n,), name="A", dtype=dtype) - B = te.placeholder((n,), name="B", dtype=dtype) - - def searchsorted_ir(A, B, C): +def test_binary_search(): + def binary_search(ib, n, i, Aptr, Bptr, Cptr): + lo = ib.allocate("int32", (1,), name="lo", scope="local") + hi = ib.allocate("int32", (1,), name="hi", scope="local") + + lo[0] = 0 + hi[0] = n - 1 + v = Bptr[i] + num_loop = int(np.log2(n)) + 1 + + with ib.for_range(0, num_loop, test=(lo[0] <= hi[0])) as _: + mid = tvm.tir.floordiv(lo[0] + hi[0], 2).astype("int32") + with ib.if_scope(Aptr[mid] < v): + lo[0] = mid + 1 + with ib.else_scope(): + with ib.if_scope(Aptr[mid] > v): + hi[0] = mid - 1 + + Cptr[i] = lo[0] + + def searchsorted_ir_cpu(A, B, C, n): ib = tvm.tir.ir_builder.create() Aptr = ib.buffer_ptr(A) Bptr = ib.buffer_ptr(B) Cptr = ib.buffer_ptr(C) with ib.for_range(0, n, name="i", kind="parallel") as i: - lo = ib.allocate("int32", (1,), name="lo", scope="local") - hi = ib.allocate("int32", (1,), name="hi", scope="local") - - lo[0] = 0 - hi[0] = n - 1 - v = Bptr[i] - num_loop = int(np.log2(n)) + 1 - - with ib.for_range(0, num_loop, test=(lo[0] <= hi[0])) as _: - mid = tvm.tir.floordiv(lo[0] + hi[0], 2).astype("int32") - with ib.if_scope(Aptr[mid] < v): - lo[0] = mid + 1 - with ib.else_scope(): - with ib.if_scope(Aptr[mid] > v): - hi[0] = mid - 1 - - Cptr[i] = lo[0] + binary_search(ib, n, i, Aptr, Bptr, Cptr) body = ib.get() return body - C = te.extern( - A.shape, - [A, B], - lambda ins, outs: searchsorted_ir(ins[0], ins[1], outs[0]), - name="searchsorted_ir", - dtype="int32", - ) - s = te.create_schedule(C.op) - - def check_target(target): - if not tvm.testing.device_enabled(target): - return - # build and invoke the kernel. - with tvm.transform.PassContext(opt_level=3, disabled_pass=["HoistIfThenElse"]): - func = tvm.build(s, [A, B, C], target) - print(tvm.lower(s, [A, B, C], simple_mode=True)) - ctx = tvm.context(target, 0) - # launch the kernel. - a_np = np.random.uniform(size=n).astype(A.dtype) - b_np = np.random.uniform(size=n).astype(B.dtype) - a_np = np.sort(a_np) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) - func(a, b, c) - ref = np.searchsorted(a_np, b_np) - tvm.testing.assert_allclose(c.asnumpy(), ref) - - check_target("llvm") - - -def test_while_gpu(): - n = 1024 - dtype = "float32" - A = te.placeholder((n,), name="A", dtype=dtype) - B = te.placeholder((n,), name="B", dtype=dtype) - idxd = tvm.tir.indexdiv - - def searchsorted_ir(A, B, C): + def searchsorted_ir_gpu(A, B, C, n): ib = tvm.tir.ir_builder.create() Aptr = ib.buffer_ptr(A) Bptr = ib.buffer_ptr(B) @@ -255,50 +215,39 @@ def searchsorted_ir(A, B, C): bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") max_threads = 32 - ib.scope_attr(bx, "thread_extent", idxd(n + max_threads - 1, max_threads)) + ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads)) ib.scope_attr(tx, "thread_extent", max_threads) tid = bx * max_threads + tx with ib.if_scope(tid < n): - lo = ib.allocate("int32", (1,), name="lo", scope="local") - hi = ib.allocate("int32", (1,), name="hi", scope="local") - - lo[0] = 0 - hi[0] = n - 1 - v = Bptr[tid] - num_loop = int(np.log2(n)) + 1 - - with ib.for_range(0, num_loop, test=(lo[0] <= hi[0])) as _: - mid = tvm.tir.floordiv(lo[0] + hi[0], 2).astype("int32") - with ib.if_scope(Aptr[mid] < v): - lo[0] = mid + 1 - with ib.else_scope(): - with ib.if_scope(Aptr[mid] > v): - hi[0] = mid - 1 - - Cptr[tid] = lo[0] + binary_search(ib, n, tid, Aptr, Bptr, Cptr) body = ib.get() return body - C = te.extern( - A.shape, - [A, B], - lambda ins, outs: searchsorted_ir(ins[0], ins[1], outs[0]), - name="searchsorted_ir", - dtype="int32", - ) - s = te.create_schedule(C.op) + n = 1024 + dtype = "float32" + A = te.placeholder((n,), name="A", dtype=dtype) + B = te.placeholder((n,), name="B", dtype=dtype) - def check_target(target): + def check_target(target, ir): if not tvm.testing.device_enabled(target): return - # build and invoke the kernel. + + C = te.extern( + A.shape, + [A, B], + lambda ins, outs: ir(ins[0], ins[1], outs[0], n), + name="searchsorted_ir", + dtype="int32", + ) + s = te.create_schedule(C.op) + with tvm.transform.PassContext(opt_level=3, disabled_pass=["HoistIfThenElse"]): func = tvm.build(s, [A, B, C], target) + ctx = tvm.context(target, 0) - # launch the kernel. a_np = np.random.uniform(size=n).astype(A.dtype) b_np = np.random.uniform(size=n).astype(B.dtype) a_np = np.sort(a_np) @@ -309,15 +258,15 @@ def check_target(target): ref = np.searchsorted(a_np, b_np) tvm.testing.assert_allclose(c.asnumpy(), ref) - check_target("cuda") - check_target("nvptx") + check_target("llvm", searchsorted_ir_cpu) + check_target("cuda", searchsorted_ir_gpu) + check_target("nvptx", searchsorted_ir_gpu) if __name__ == "__main__": - # test_prefetch() - # test_if() - # test_for() - # test_cpu() - # test_gpu() - # test_while_cpu() - test_while_gpu() + test_prefetch() + test_if() + test_for() + test_cpu() + test_gpu() + test_binary_search() From 8e62293677e5b06cad780f1a71a1a3c5e13e80d5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Jan 2021 22:09:28 +0900 Subject: [PATCH 12/16] fix corner case --- tests/python/unittest/test_tir_ir_builder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index be2334bab5d7..881d4bbba58d 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -190,6 +190,10 @@ def binary_search(ib, n, i, Aptr, Bptr, Cptr): with ib.else_scope(): with ib.if_scope(Aptr[mid] > v): hi[0] = mid - 1 + with ib.else_scope(): + # force loop to terminate + lo[0] = mid + hi[0] = mid - 1 Cptr[i] = lo[0] From a275666d75ccf73b49b2a5a1b068f1af82a2d5d8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 1 Feb 2021 07:19:03 +0900 Subject: [PATCH 13/16] improve binary search --- tests/python/unittest/test_tir_ir_builder.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 881d4bbba58d..dd01cec0d0de 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -179,21 +179,16 @@ def binary_search(ib, n, i, Aptr, Bptr, Cptr): hi = ib.allocate("int32", (1,), name="hi", scope="local") lo[0] = 0 - hi[0] = n - 1 + hi[0] = n v = Bptr[i] num_loop = int(np.log2(n)) + 1 - with ib.for_range(0, num_loop, test=(lo[0] <= hi[0])) as _: - mid = tvm.tir.floordiv(lo[0] + hi[0], 2).astype("int32") + with ib.for_range(0, num_loop, test=(lo[0] < hi[0])) as _: + mid = lo[0] + tvm.tir.floordiv(hi[0] - lo[0], 2).astype("int32") with ib.if_scope(Aptr[mid] < v): lo[0] = mid + 1 with ib.else_scope(): - with ib.if_scope(Aptr[mid] > v): - hi[0] = mid - 1 - with ib.else_scope(): - # force loop to terminate - lo[0] = mid - hi[0] = mid - 1 + hi[0] = mid Cptr[i] = lo[0] From 013f1291f1d144eac6985994d2ffcb5afa42318d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 1 Feb 2021 07:29:22 +0900 Subject: [PATCH 14/16] fix lint --- include/tvm/tir/stmt.h | 4 ++-- src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 4ddc9c0adf34..0d0241e5625d 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -860,8 +860,8 @@ class ForNode : public StmtNode { */ class For : public Stmt { public: - TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional test = NullOpt, - Optional thread_binding = NullOpt, + TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, + Optional test = NullOpt, Optional thread_binding = NullOpt, Map annotations = Map(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index ab9d248b86be..6e074f624cf7 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -968,8 +968,8 @@ class TensorCoreIRMutator : public StmtExprMutator { scaled_extent_value = ori_extent_value / scale_factor; } PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->test, op->thread_binding, - op->annotations); + stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->test, + op->thread_binding, op->annotations); } } return stmt; From a1b2c4a57e136186165e54bb1148faa44ad0a899 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 2 Feb 2021 04:15:54 +0900 Subject: [PATCH 15/16] add comment --- include/tvm/tir/stmt.h | 2 +- python/tvm/tir/ir_builder.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0d0241e5625d..e497407a5877 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -802,7 +802,7 @@ class ForNode : public StmtNode { ForKind kind; /*! \brief The body of the for loop. */ Stmt body; - /*! \brief The test condition of the for loop. */ + /*! \brief The additional termination condition of the for loop. */ Optional test; /*! diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index bbe63ecf1c8b..b29822642c14 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -217,6 +217,9 @@ def for_range(self, begin, end, test=None, name="i", dtype="int32", kind="serial end : Expr The end iteration scope + test : Expr, optional + The additional termination condition. + name : str, optional The name of iteration variable, if no input names, using typical index names i, j, k, then i_nidx @@ -249,7 +252,8 @@ def for_range(self, begin, end, test=None, name="i", dtype="int32", kind="serial extent = end if begin == 0 else (end - begin) if test is not None: - assert kind == "serial" + msg = "A general termination condition is only supported for a serial loop." + assert kind == "serial", msg def _exit_cb(): if kind == "serial": From 07824b0cc01fbd6e971ecd419130047d1c125e1e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 2 Feb 2021 05:51:05 +0900 Subject: [PATCH 16/16] swap arguments order to avoid breaking existing tests --- python/tvm/tir/ir_builder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index b29822642c14..e6a2c4af0bc4 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -206,7 +206,7 @@ def scope_attr(self, node, attr_key, value): value = op.max(1, value) self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x)) - def for_range(self, begin, end, test=None, name="i", dtype="int32", kind="serial"): + def for_range(self, begin, end, name="i", test=None, dtype="int32", kind="serial"): """Create a for iteration scope. Parameters @@ -217,13 +217,13 @@ def for_range(self, begin, end, test=None, name="i", dtype="int32", kind="serial end : Expr The end iteration scope - test : Expr, optional - The additional termination condition. - name : str, optional The name of iteration variable, if no input names, using typical index names i, j, k, then i_nidx + test : Expr, optional + The additional termination condition. + dtype : str, optional The data type of iteration variable.