Skip to content
Closed
10 changes: 8 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,9 @@ class ForNode : public StmtNode {
ForKind kind;
/*! \brief The body of the for loop. */
Stmt body;
/*! \brief The additional termination condition of the for loop. */
Optional<PrimExpr> test;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to have a RFC discussion, since different strategies changes to the IR can have different implications


/*!
* \brief Only valid when kind == ForKind::kThreadBinding
* The context thread that this loop variable bounds to.
Expand All @@ -823,6 +826,7 @@ class ForNode : public StmtNode {
v->Visit("extent", &extent);
v->Visit("kind", &kind);
v->Visit("body", &body);
v->Visit("test", &test);
v->Visit("thread_binding", &thread_binding);
v->Visit("annotations", &annotations);
v->Visit("span", &span);
Expand All @@ -831,7 +835,8 @@ class ForNode : public StmtNode {
bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations);
equal(test, other->test) && equal(thread_binding, other->thread_binding) &&
equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -840,6 +845,7 @@ class ForNode : public StmtNode {
hash_reduce(extent);
hash_reduce(kind);
hash_reduce(body);
hash_reduce(test);
hash_reduce(thread_binding);
hash_reduce(annotations);
}
Expand All @@ -855,7 +861,7 @@ class ForNode : public StmtNode {
class For : public Stmt {
public:
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
Optional<IterVar> thread_binding = NullOpt,
Optional<PrimExpr> test = NullOpt, Optional<IterVar> thread_binding = NullOpt,
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
Expand Down
11 changes: 9 additions & 2 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def scope_attr(self, node, attr_key, value):
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
def for_range(self, begin, end, name="i", test=None, dtype="int32", kind="serial"):
"""Create a for iteration scope.

Parameters
Expand All @@ -221,6 +221,9 @@ def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
The name of iteration variable, if no input names,
using typical index names i, j, k, then i_nidx

test : Expr, optional
The additional termination condition.

dtype : str, optional
The data type of iteration variable.

Expand Down Expand Up @@ -248,6 +251,10 @@ def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
loop_var = _expr.Var(name, dtype=dtype)
extent = end if begin == 0 else (end - begin)

if test is not None:
msg = "A general termination condition is only supported for a serial loop."
assert kind == "serial", msg

def _exit_cb():
if kind == "serial":
kind_id = _stmt.ForKind.SERIAL
Expand All @@ -259,7 +266,7 @@ def _exit_cb():
kind_id = _stmt.ForKind.UNROLLED
else:
raise ValueError("Unknown kind")
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq()))
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq(), test))

return WithScope(loop_var, _exit_cb)

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(
extent,
kind,
body,
test=None,
thread_binding=None,
annotations=None,
span=None,
Expand All @@ -149,6 +150,7 @@ def __init__(
extent,
kind,
body,
test,
thread_binding,
annotations,
span,
Expand Down
22 changes: 12 additions & 10 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,16 +541,18 @@ def nms_inner_loop(ib, j):

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, nkeep) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we have already reached max_output_size
# boxes
# TODO(masahi): Add TIR while loop to realize early exit from the outer loop
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
with ib.else_scope():
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we have already reached max_output_size
# boxes
with ib.for_range(0, nkeep, test=(num_valid_boxes_local[0] < max_output_size)) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
nms_inner_loop(ib, j)

with ib.else_scope():
with ib.for_range(0, nkeep) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
nms_inner_loop(ib, j)

with ib.if_scope(tx + 0 == 0):
Expand Down
7 changes: 6 additions & 1 deletion src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,12 @@ inline const char* ForKind2String(ForKind t) {
Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", "
<< Print(op->min + op->extent) << ")";
<< Print(op->min + op->extent);
if (op->test) {
doc << ", (" << Print(op->test.value()) << "))";
} else {
doc << ")";
}
if (op->kind != ForKind::kSerial) {
doc << " " << Doc::StrLiteral(ForKind2String(op->kind));
}
Expand Down
7 changes: 4 additions & 3 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
CodeGenLLVM::VisitStmt_(op);
} else if (op->kind == ForKind::kParallel) {
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body, op->test,
op->thread_binding, op->annotations),
0);
} else {
Expand All @@ -996,13 +996,14 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
parallel_env_.in_parallel_loop = true;
if (parallel_env_.stride_pattern) {
CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task),
op->loop_var, op->body);
op->loop_var, op->body, op->test);
} else {
PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
PrimExpr begin = min(task_id * step, op->extent);
PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin), MakeValue(end),
llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body);
llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body,
op->test);
}
parallel_env_.in_parallel_loop = false;
++parallel_env_.parallel_loop_count;
Expand Down
15 changes: 11 additions & 4 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
}

void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
const Var& loop_var, const Stmt& body) {
const Var& loop_var, const Stmt& body, Optional<PrimExpr> test) {
using llvm::BasicBlock;
BasicBlock* pre_block = builder_->GetInsertBlock();
BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_);
Expand All @@ -673,8 +673,14 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va
loop_value->addIncoming(begin, pre_block);
ICHECK(!var_map_.count(loop_var.get()));
var_map_[loop_var.get()] = loop_value;
builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end,
md_very_likely_branch_);

llvm::Value* less_than = CreateLT(loop_var.dtype(), loop_value, end);
llvm::Value* cond = less_than;
if (test) {
cond = builder_->CreateAnd(less_than, MakeValue(test.value()));
}
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);

builder_->SetInsertPoint(for_body);
this->VisitStmt(body);
var_map_.erase(loop_var.get());
Expand Down Expand Up @@ -1325,7 +1331,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
ICHECK(op->kind == ForKind::kSerial);
}
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body,
op->test);
}

void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
// Create serial for
void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
const Var& loop_var, const Stmt& body);
const Var& loop_var, const Stmt& body, Optional<PrimExpr> test);
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index);
// The IRBuilder.
Expand Down
7 changes: 6 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,12 @@ void CodeGenC::VisitStmt_(const ForNode* op) {
ICHECK(is_zero(op->min));
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n";
stream << ' ' << vid << " = 0; " << vid << " < " << extent;
if (op->test) {
std::string test = PrintExpr(op->test.value());
stream << " && (" << test << ")";
}
stream << "; ++" << vid << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
Expand Down
6 changes: 3 additions & 3 deletions src/te/operation/hybrid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = tir::Substitute(body, rmap);
under_outer = false;
return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body,
return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body, op->test,
op->thread_binding, op->annotations);
} else if (under_outer) {
Stmt body = this->VisitStmt(op->body);
Expand Down Expand Up @@ -332,7 +332,7 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar,
return AttrStmt(iter_var, "thread_extent", op->extent, body);
} else {
return For(op->loop_var, op->min, op->extent, IterVarTypeToForKind(attr->iter_type),
op->body, op->thread_binding, op->annotations);
op->body, op->test, op->thread_binding, op->annotations);
}
}
return StmtMutator::VisitStmt_(op);
Expand Down Expand Up @@ -414,7 +414,7 @@ Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range>
kind = IterVarTypeToForKind(stage->iter_var_attrs[target]->iter_type);
}
const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
return For(target->var, range->min, range->extent, kind, body, op->thread_binding,
return For(target->var, range->min, range->extent, kind, body, op->test, op->thread_binding,
op->annotations);
}
};
Expand Down
4 changes: 2 additions & 2 deletions src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -968,8 +968,8 @@ class TensorCoreIRMutator : public StmtExprMutator {
scaled_extent_value = ori_extent_value / scale_factor;
}
PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->thread_binding,
op->annotations);
stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->test,
op->thread_binding, op->annotations);
}
}
return stmt;
Expand Down
8 changes: 5 additions & 3 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

// For
For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
Optional<IterVar> thread_binding, Map<String, ObjectRef> annotations, Span span) {
Optional<PrimExpr> test, Optional<IterVar> thread_binding,
Map<String, ObjectRef> annotations, Span span) {
ICHECK(min.defined());
ICHECK(extent.defined());
ICHECK(min.dtype().is_scalar());
Expand All @@ -143,16 +144,17 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
node->extent = std::move(extent);
node->kind = kind;
node->body = std::move(body);
node->test = std::move(test);
node->thread_binding = std::move(thread_binding);
node->annotations = std::move(annotations);
node->span = std::move(span);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.For").set_body_typed(
[](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body,
[](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, Optional<PrimExpr> test,
Optional<IterVar> thread_binding, Optional<Map<String, ObjectRef>> annotations, Span span) {
return For(loop_var, min, extent, static_cast<ForKind>(kind), body, thread_binding,
return For(loop_var, min, extent, static_cast<ForKind>(kind), body, test, thread_binding,
annotations.value_or(Map<String, ObjectRef>()), span);
});

Expand Down
11 changes: 11 additions & 0 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ void StmtVisitor::VisitStmt_(const ForNode* op) {
this->VisitExpr(op->min);
this->VisitExpr(op->extent);
this->VisitStmt(op->body);
if (op->test) {
this->VisitExpr(op->test.value());
}
}

void StmtVisitor::VisitStmt_(const AllocateNode* op) {
Expand Down Expand Up @@ -168,13 +171,21 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) {
PrimExpr min = this->VisitExpr(op->min);
PrimExpr extent = this->VisitExpr(op->extent);
Stmt body = this->VisitStmt(op->body);
Optional<PrimExpr> test = NullOpt;
if (op->test) {
test = this->VisitExpr(op->test.value());
}

if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->min = std::move(min);
n->extent = std::move(extent);
n->body = std::move(body);
if (test) {
n->test = std::move(test);
}
return Stmt(n);
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/tir/transforms/hoist_if_then_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ class HoistCandidateSelector final : public StmtExprVisitor {
HoistCandidateSelector() { InitRecorder(); }

void VisitStmt_(const ForNode* op) final {
if (op->test) {
// Do not hoist if this is a while loop
return;
}
// If already recording complete,
// then stop tracing
if (RecordingComplete()) {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class IRConvertSSA final : public StmtExprMutator {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<ForNode>();
return For(new_var, op->min, op->extent, op->kind, op->body, op->thread_binding,
return For(new_var, op->min, op->extent, op->kind, op->body, op->test, op->thread_binding,
op->annotations);
} else {
defined_.insert(v.get());
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class DataTypeRewriter : public StmtExprMutator {
PrimExpr e = VisitExpr(op->loop_var);
Var var = Downcast<Var>(e);
return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body,
op->thread_binding, op->annotations);
op->test, op->thread_binding, op->annotations);
}

Stmt VisitStmt_(const AttrStmtNode* op) final {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ class StoragePlanRewriter : public StmtExprMutator {
auto& svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body),
return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body), op->test,
op->thread_binding, op->annotations);
} else {
return StmtExprMutator::VisitStmt_(op);
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/unroll_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class LoopUnroller : public StmtExprMutator {
} else {
if (auto_unroll) {
if (op->kind != ForKind::kUnrolled) {
return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body,
return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body, op->test,
op->thread_binding, op->annotations);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding,
return For(op->loop_var, op->min, extent, op->kind, body, op->test, op->thread_binding,
op->annotations);
}
}
Expand Down
Loading