diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 215a6c9c5b1b..52f1c20134f2 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -647,6 +647,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { return builder_->CreateIsNull(MakeValue(op->args[0])); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + CHECK_EQ(op->args[0].type().lanes(), 1) + << "if_then_else can only take scalar condition"; using llvm::BasicBlock; BasicBlock* then_block = BasicBlock::Create( *ctx_, "if_then", function_); diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 19874a803657..282f1eee1399 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -83,6 +83,19 @@ class Vectorizer : public IRMutator { // user mutate from parent. using IRMutator::Mutate; + Stmt Mutate(Stmt stmt) final { + CHECK(!need_scalarize_); + + Stmt ret = IRMutator::Mutate(stmt); + if (need_scalarize_) { + need_scalarize_ = false; + return Scalarize(stmt); + } else { + return ret; + } + } + + Expr Mutate_(const Add* op, const Expr &e) final { return AddSubVec(op, e); } @@ -200,10 +213,37 @@ class Vectorizer : public IRMutator { return e; } } + // IfThenElse expr: a vector condition sets need_scalarize_ so the enclosing Stmt is scalarized; otherwise, when any argument changed, both branches are broadcast to a common lane count. + Expr MutateIfThenElseExpr_(const Call *op, const Expr& e) { + Expr cond = this->Mutate(op->args[0]); + if (cond.type().is_vector()) { + need_scalarize_ = true; + return e; + } + Expr t = this->Mutate(op->args[1]); + Expr f = this->Mutate(op->args[2]); + if (cond.same_as(op->args[0]) && + t.same_as(op->args[1]) && + f.same_as(op->args[2])) { + return e; + } else { + int lanes = std::max(t.type().lanes(), f.type().lanes()); + t = BroadcastTo(t, lanes); + f = BroadcastTo(f, lanes); + return Call::make( + op->type.with_lanes(lanes), op->name, + {cond, t, f}, op->call_type, op->func, op->value_index); + } + } // Call Expr Mutate_(const Call* op, const Expr& e) final { + 
if (op->name == intrinsic::tvm_if_then_else) { + return MutateIfThenElseExpr_(op, e); + } int lane = 0; Array new_args = MutateArray(op->args, &lane); + + // normal code path: every call other than tvm_if_then_else, which returned above. if (op->args.same_as(new_args)) { return e; } else { @@ -367,6 +407,8 @@ class Vectorizer : public IRMutator { int var_lanes_; // ramp representing the var. Expr ramp_; + // flag to mark requirement of scalarization; set by MutateIfThenElseExpr_, checked and reset in Mutate(Stmt). + bool need_scalarize_{false}; // The lets std::unordered_map lets_; // mutate array, with given lane requirement diff --git a/tests/python/unittest/test_pass_vectorize.py b/tests/python/unittest/test_pass_vectorize.py index 45bb4362d68e..1fbcc655ac80 100644 --- a/tests/python/unittest/test_pass_vectorize.py +++ b/tests/python/unittest/test_pass_vectorize.py @@ -53,7 +53,36 @@ def test_vectorize_with_if(): assert stmt.then_case.value.dtype == "float32x4" assert isinstance(stmt.else_case, tvm.stmt.For) +def test_vectorize_if_then_else(): + n = tvm.var('n') + x = tvm.var('x') + ib = tvm.ir_builder.create() + A = ib.pointer("float32", name="A") + with ib.for_range(0, 4, for_type="vectorize") as i: + A[i] = tvm.call_intrin("float32", "tvm_if_then_else", + i > 0, + A[i] + 1, A[i]) + stmt = ib.get() + stmt = tvm.ir_pass.VectorizeLoop(stmt) + assert isinstance(stmt, tvm.stmt.For) + + + ib = tvm.ir_builder.create() + A = ib.pointer("float32", name="A") + with ib.for_range(0, n) as k: + with ib.for_range(0, 4, for_type="vectorize") as i: + A[k * 4 + i] = tvm.call_intrin("float32", "tvm_if_then_else", + k > 0, + A[k * 4 + i], 0) + stmt = ib.get() + assert isinstance(stmt.body, tvm.stmt.For) + stmt = tvm.ir_pass.VectorizeLoop(stmt) + assert not isinstance(stmt.body, tvm.stmt.For) + assert isinstance(stmt.body.value.args[2], tvm.expr.Broadcast) + + if __name__ == "__main__": test_vectorize_vector() test_vectorize_with_if() test_vectorize_loop() + test_vectorize_if_then_else()