Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ class IfThenElseNode : public StmtNode {
/*! \brief The branch to be executed when condition is true. */
Stmt then_case;
/*! \brief The branch to be executed when condition is false, can be null. */
Stmt else_case;
Optional<Stmt> else_case;

void VisitAttrs(AttrVisitor* v) {
v->Visit("condition", &condition);
Expand Down Expand Up @@ -854,7 +854,7 @@ class IfThenElseNode : public StmtNode {
*/
class IfThenElse : public Stmt {
public:
TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt(),
TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case = NullOpt,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
Expand Down
12 changes: 5 additions & 7 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,19 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
}
}

Stmt then_case, else_case;
Stmt then_case;
Optional<Stmt> else_case;
{
With<ConstraintContext> ctx(analyzer_, real_condition);
then_case = this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
if (op->else_case) {
With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition)));
else_case = this->VisitStmt(op->else_case);
else_case = this->VisitStmt(op->else_case.value());
}
if (is_one(real_condition)) return then_case;
if (is_zero(real_condition)) {
if (else_case.defined()) {
return else_case;
}
return Evaluate(0);
return else_case.value_or(Evaluate(0));
}

if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
Expand Down
4 changes: 2 additions & 2 deletions src/arith/ir_visitor_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
With<ConstraintContext> constraint(&analyzer_, real_condition);
this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
if (op->else_case) {
With<ConstraintContext> constraint(&analyzer_, analyzer_.rewrite_simplify(Not(real_condition)));
this->VisitStmt(op->else_case);
this->VisitStmt(op->else_case.value());
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,11 +381,11 @@ void CodeGenHybrid::VisitStmt_(const IfThenElseNode* op) {
PrintStmt(op->then_case);
indent_ -= tab_;

if (!is_noop(op->else_case)) {
if (op->else_case && !is_noop(op->else_case.value())) {
PrintIndent();
stream << "else:\n";
indent_ += tab_;
PrintStmt(op->else_case);
PrintStmt(op->else_case.value());
indent_ -= tab_;
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,8 @@ Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) {
Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << PrintBody(op->then_case);
if (!is_one(op->condition) && op->else_case.defined()) {
doc << " else" << PrintBody(op->else_case);
if (!is_one(op->condition) && op->else_case) {
doc << " else" << PrintBody(op->else_case.value());
}
return doc;
}
Expand Down
4 changes: 2 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1244,9 +1244,9 @@ Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case));
if (!is_one(op->condition) && op->else_case.defined()) {
if (!is_one(op->condition) && op->else_case) {
doc << Doc::NewLine();
doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case));
doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case.value()));
}
return doc;
}
Expand Down
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1759,14 +1759,14 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_);
auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_);
if (op->else_case.defined()) {
if (op->else_case) {
auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_);
builder_->CreateCondBr(cond, then_block, else_block);
builder_->SetInsertPoint(then_block);
this->VisitStmt(op->then_case);
builder_->CreateBr(end_block);
builder_->SetInsertPoint(else_block);
this->VisitStmt(op->else_case);
this->VisitStmt(op->else_case.value());
builder_->CreateBr(end_block);
} else {
builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
Expand Down
4 changes: 2 additions & 2 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -936,11 +936,11 @@ void CodeGenC::VisitStmt_(const IfThenElseNode* op) {
PrintStmt(op->then_case);
this->EndScope(then_scope);

if (op->else_case.defined()) {
if (op->else_case) {
PrintIndent();
stream << "} else {\n";
int else_scope = BeginScope();
PrintStmt(op->else_case);
PrintStmt(op->else_case.value());
this->EndScope(else_scope);
}
PrintIndent();
Expand Down
4 changes: 2 additions & 2 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
spirv::Value cond = MakeValue(op->condition);
spirv::Label then_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
if (op->else_case.defined()) {
if (op->else_case) {
spirv::Label else_label = builder_->NewLabel();
builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label);
Expand All @@ -638,7 +638,7 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
builder_->MakeInst(spv::OpBranch, merge_label);
// else block
builder_->StartLabel(else_label);
this->VisitStmt(op->else_case);
this->VisitStmt(op->else_case.value());
builder_->MakeInst(spv::OpBranch, merge_label);
} else {
builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
Expand Down
4 changes: 2 additions & 2 deletions src/target/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -475,13 +475,13 @@ void CodeGenStackVM::VisitStmt_(const IfThenElseNode* op) {
int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
this->PushOp(StackVM::POP);
this->Push(op->then_case);
if (op->else_case.defined()) {
if (op->else_case) {
int64_t label_then_jump = this->GetPC();
int64_t then_jump = this->PushOp(StackVM::RJUMP, 0);
int64_t else_begin = this->GetPC();
this->SetOperand(else_jump, else_begin - label_ejump);
this->PushOp(StackVM::POP);
this->Push(op->else_case);
this->Push(op->else_case.value());
int64_t if_end = this->GetPC();
this->SetOperand(then_jump, if_end - label_then_jump);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
if (op->else_case) {
// Visit else branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitStmt(op->else_case);
StmtExprVisitor::VisitStmt(op->else_case.value());
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/tir/analysis/estimate_flops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,

TResult VisitStmt_(const IfThenElseNode* branch) override {
TResult cond = VisitExpr(branch->condition);
if (branch->else_case.defined()) {
cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case));
if (branch->else_case) {
cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case.value()));
} else {
cond += VisitStmt(branch->then_case);
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

// IfThenElse
IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case, Span span) {
IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case, Span span) {
ICHECK(condition.defined());
ICHECK(then_case.defined());
// else_case may be null.
Expand Down Expand Up @@ -670,7 +670,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->Print(op->then_case);
p->indent -= 2;

if (!op->else_case.defined()) {
if (!op->else_case) {
break;
}

Expand Down
10 changes: 5 additions & 5 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
this->VisitExpr(op->condition);
this->VisitStmt(op->then_case);
if (op->else_case.defined()) {
this->VisitStmt(op->else_case);
if (op->else_case) {
this->VisitStmt(op->else_case.value());
}
}

Expand Down Expand Up @@ -352,9 +352,9 @@ Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) {
Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
else_case = this->VisitStmt(op->else_case);
Optional<Stmt> else_case = NullOpt;
if (op->else_case) {
else_case = this->VisitStmt(op->else_case.value());
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/common_subexpr_elim_tools.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,9 @@ void ComputationsDoneBy::VisitStmt_(const IfThenElseNode* op) {
table_of_computations_.clear();

ComputationTable computations_done_by_else;
if (op->else_case.defined()) {
// And finally calls the VisitStmt() method on the `then_case` child
VisitStmt(op->else_case);
if (op->else_case) {
// And finally calls the VisitStmt() method on the `else_case` child
VisitStmt(op->else_case.value());
computations_done_by_else = table_of_computations_;
table_of_computations_.clear();
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
if (op->else_case) {
// Visit else branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitStmt(op->else_case);
StmtExprVisitor::VisitStmt(op->else_case.value());
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/coproc_sync.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ class CoProcInstDepDetector : public StmtVisitor {
first_state_.clear();
last_state_.clear();
}
if (op->else_case.defined()) {
this->VisitStmt(op->else_case);
if (op->else_case) {
this->VisitStmt(op->else_case.value());
if (last_state_.node != nullptr) {
curr_state.node = op;
MatchFixEnterPop(first_state_);
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/inject_virtual_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,11 @@ class VTInjector : public arith::IRMutatorWithAnalyzer {
visit_touched_var_ = false;
ICHECK_EQ(max_loop_depth_, 0);
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
Optional<Stmt> else_case = NullOpt;
if (op->else_case) {
int temp = max_loop_depth_;
max_loop_depth_ = 0;
else_case = this->VisitStmt(op->else_case);
else_case = this->VisitStmt(op->else_case.value());
max_loop_depth_ = std::max(temp, max_loop_depth_);
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
} else if (const auto* ite = s.as<IfThenElseNode>()) {
auto n = make_object<IfThenElseNode>(*ite);
ICHECK(is_no_op(n->then_case));
ICHECK(!n->else_case.defined());
ICHECK(!n->else_case);
n->then_case = body;
body = Stmt(n);
} else if (const auto* seq = s.as<SeqStmtNode>()) {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/lift_attr_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,15 @@ class AttrScopeLifter : public StmtMutator {
}

Stmt VisitStmt_(const IfThenElseNode* op) final {
if (!op->else_case.defined()) {
if (!op->else_case) {
return StmtMutator::VisitStmt_(op);
}
Stmt then_case = this->VisitStmt(op->then_case);
ObjectRef first_node;
PrimExpr first_value;
std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_);
Stmt else_case = this->VisitStmt(op->else_case);
Stmt else_case = this->VisitStmt(op->else_case.value());
if (attr_node_.defined() && attr_value_.defined() && first_node.defined() &&
first_value.defined() && attr_node_.same_as(first_node) &&
ValueSame(attr_value_, first_value)) {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/profile_instrumentation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ class LoopAnalyzer : public StmtExprVisitor {
} else if (stmt->IsInstance<IfThenElseNode>()) {
const IfThenElseNode* n = stmt.as<IfThenElseNode>();
unsigned height = TraverseLoop(n->then_case, parent_depth, has_parallel);
if (n->else_case.defined()) {
height = std::max(height, TraverseLoop(n->else_case, parent_depth, has_parallel));
if (n->else_case) {
height = std::max(height, TraverseLoop(n->else_case.value(), parent_depth, has_parallel));
}
return height;
} else if (stmt->IsInstance<ForNode>()) {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/remove_no_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class NoOpRemover : public StmtMutator {
Stmt VisitStmt_(const IfThenElseNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<IfThenElseNode>();
if (op->else_case.defined()) {
if (is_no_op(op->else_case)) {
if (op->else_case) {
if (is_no_op(op->else_case.value())) {
if (is_no_op(op->then_case)) {
return MakeEvaluate(op->condition);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
if (const int64_t* as_int = as_const_int(cond)) {
if (*as_int) {
return this->VisitStmt(op->then_case);
} else if (op->else_case.defined()) {
return this->VisitStmt(op->else_case);
} else if (op->else_case) {
return this->VisitStmt(op->else_case.value());
} else {
return Evaluate(0);
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/storage_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
if (op->else_case.defined()) {
if (op->else_case) {
scope_.push_back(std::vector<StmtEntry>());
this->VisitStmt(op->else_case);
this->VisitStmt(op->else_case.value());
auto v = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,9 +490,9 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
return Scalarize(GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
else_case = this->VisitStmt(op->else_case);
Optional<Stmt> else_case = NullOpt;
if (op->else_case) {
else_case = this->VisitStmt(op->else_case.value());
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
Expand Down