From 8695e7afef6770b4b5085c8107711aedd9bd41ae Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 3 Jun 2022 09:10:18 -0500 Subject: [PATCH 01/10] [TIR][Arith] Use non-inlined bindings when proving conditional --- src/tir/transforms/simplify.cc | 20 +++++++++++++++++ .../unittest/test_tir_transform_simplify.py | 22 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 85f405be447a..15649dfe790c 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -70,6 +70,10 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { // because the call to simplify will always inline the var. analyzer_->Bind(op->var, value); return this->VisitStmt(op->body); + } else if (SideEffect(op->value) <= CallEffectKind::kPure) { + // Even if we aren't replacing all occurrences, they may be + // necessary for proving conditional statements. + non_inlined_bindings_.Set(op->var, value); } Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { @@ -82,6 +86,20 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } } + Stmt VisitStmt_(const IfThenElseNode* op) { + PrimExpr cond = analyzer_->Simplify(Substitute(op->condition, non_inlined_bindings_)); + if (const int64_t* as_int = as_const_int(cond)) { + if (*as_int) { + return this->VisitStmt(op->then_case); + } else if (op->else_case.defined()) { + return this->VisitStmt(op->else_case); + } else { + return Evaluate(0); + } + } + return Parent::VisitStmt_(op); + } + Stmt VisitStmt_(const StoreNode* op) final { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; return Stmt(); @@ -114,6 +132,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } return true; } + + Map non_inlined_bindings_; }; } // namespace arith diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 4f727cd89b12..b24714c03f64 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -391,5 +391,27 @@ def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32): A[i, j] = 2 +class TestProveConditionUsingLet(BaseBeforeAfter): + """Simplify conditions using non-inlined let bindings + + Not all let bindings are inlined when they occur in later + expressions. However, even if they are not inlined, they may be + used to prove the value of a condition. + """ + + @T.prim_func + def before(A: T.Buffer[4, "bool"]): + for i in T.serial(4): + condition = i < 3 + if condition or i >= 3: + A[i] = condition + + @T.prim_func + def expected(A: T.Buffer[4, "bool"]): + for i in T.serial(4): + condition = i < 3 + A[i] = condition + + if __name__ == "__main__": tvm.testing.main() From 6bcffd485a186b756169c304a84e91bb252a479e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 3 Jun 2022 09:49:31 -0500 Subject: [PATCH 02/10] [TIR][Arith] Recognize Var when used as a literal constraint --- src/arith/rewrite_simplify.cc | 8 +++- .../unittest/test_tir_transform_simplify.py | 47 +++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index a168e1f0836c..b1c4ffc3e9cb 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -228,7 +228,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) { size_t old_literal_size = literal_constraints_.size(); // we will compare the already simplified result with the constraint, - // so simplify the constarint as well + // so simplify the constraint as well PrimExpr new_constraint = operator()(constraint); for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) { if (SideEffect(subconstraint) <= CallEffectKind::kPure) { @@ -1652,6 +1652,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { Var var = GetRef(op); + if (op->dtype == DataType::Bool()) { + if (auto match = TryMatchLiteralConstraint(var)) { + return match.value(); + } + } + auto it = var_map_.find(var); if (it != var_map_.end()) { return it->second; diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index b24714c03f64..6057de9b4678 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -413,5 +413,52 @@ def expected(A: T.Buffer[4, "bool"]): A[i] = condition +class TestProveLetCondition(BaseBeforeAfter): + """Simplify conditions using non-inlined let bindings + + Not all let bindings are inlined when they occur in later + expressions. However, even if they are not inlined, they may be + used to prove the value of a condition. + """ + + @T.prim_func + def before(A: T.Buffer[4, "bool"]): + for i in T.serial(4): + condition = i < 3 + if i < 3: + if condition: + A[i] = condition + + @T.prim_func + def expected(A: T.Buffer[4, "bool"]): + for i in T.serial(4): + condition = i < 3 + if i < 3: + A[i] = condition + + +class TestProveRepeatedLetCondition(BaseBeforeAfter): + """Simplify conditions using non-inlined let bindings + + A variable may be used as a literal constraint, and be recognized + as being True within the context of the constraint. + """ + + @T.prim_func + def before(A: T.Buffer[4, "bool"]): + for i in T.serial(4): + condition = i < 3 + if condition: + if condition: + A[i] = condition + + @T.prim_func + def expected(A: T.Buffer[4, "bool"]): + for i in T.serial(4): + condition = i < 3 + if condition: + A[i] = True + + if __name__ == "__main__": tvm.testing.main() From 8b5646c4a49154fc20965fdb9207c13c7802ab26 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 3 Jun 2022 13:14:39 -0500 Subject: [PATCH 03/10] [TIR][Arith] Added simplification of constrained if_then_else op This feels like it should definitely be part of RewriteSimplify, but that will require making CanInlineLet be a virtual function. --- src/tir/transforms/simplify.cc | 16 ++++++++++++++++ .../unittest/test_tir_transform_simplify.py | 14 ++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 15649dfe790c..1a61bf23432a 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -100,6 +101,21 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return Parent::VisitStmt_(op); } + PrimExpr VisitExpr_(const CallNode* op) { + if (op->op.same_as(builtin::if_then_else())) { + PrimExpr cond = this->VisitExpr(op->args[0]); + cond = analyzer_->Simplify(Substitute(std::move(cond), non_inlined_bindings_)); + if (const int64_t* as_int = as_const_int(cond)) { + if (*as_int) { + return this->VisitExpr(op->args[1]); + } else { + return this->VisitExpr(op->args[2]); + } + } + } + return Parent::VisitExpr_(op); + } + Stmt VisitStmt_(const StoreNode* op) final { LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; return Stmt(); diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 6057de9b4678..0f436f41baca 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -460,5 +460,19 @@ def expected(A: T.Buffer[4, "bool"]): A[i] = True +class TestIfThenElseExpr(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + if i < 12: + A[i] = T.if_then_else(i < 12, 1.0, 2.0, dtype="float32") + + @T.prim_func + def expected(A: T.Buffer[16, "float32"]): + for i in T.serial(16): + if i < 12: + A[i] = 1.0 + + if __name__ == "__main__": tvm.testing.main() From 09c1ab5d45457c93ca8e40c9ca8cd47a40780281 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 20 May 2022 10:28:43 -0500 Subject: [PATCH 04/10] [TIR] Implemented HoistExpression transformation This is a generalized form of HoistIfThenElse, which can also hoist Let bindings, or portions of conditional expressions. This will be used in upcoming changes to separate compute loops into a slow loop that handles edge cases and a fast branchless loop. --- include/tvm/tir/transform.h | 8 + python/tvm/tir/transform/transform.py | 70 +++ src/tir/transforms/hoist_expression.cc | 526 ++++++++++++++++++ .../test_tir_transform_hoist_expression.py | 428 ++++++++++++++ 4 files changed, 1032 insertions(+) create mode 100644 src/tir/transforms/hoist_expression.cc create mode 100644 tests/python/unittest/test_tir_transform_hoist_expression.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 24c3cfa78f72..778a2e6bd1b2 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -363,6 +363,14 @@ TVM_DLL Pass PointerValueTypeRewrite(); */ TVM_DLL Pass HoistIfThenElse(); +/*! + * \brief Hoist loop-invariant IfThenElse nodes to + * outside the elligible loops. + * + * \return The pass. + */ +TVM_DLL Pass HoistExpression(); + /*! * \brief Lower cross-thread reduction from thread * bindings to intrinsic function calls. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 802fdc576c41..50e183f81094 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -17,6 +17,8 @@ """Wrapping existing transformations.""" # pylint: disable=invalid-name from typing import Optional +import enum + from . import _ffi_api from . import function_pass as _fpass @@ -612,6 +614,74 @@ def HoistIfThenElse(variant: Optional[str] = None): return _ffi_api.HoistIfThenElse() # type: ignore +class HoistedConditionals(enum.Flag): + """Flags for use in HoistExpressionConfig.conditional_types + + Each bitflag represents a type of expression that should be + hoisted to the outermost loop possible. + """ + + Never = 0 + """ No hoisting of conditionals """ + + IfElseStmt = 1 + """ If set, look for hoist candidates in IfElseStmt """ + + IfElseExpr = 2 + """ If set, look for hoist candidates in tir.if_then_else """ + + BooleanExpression = 4 + """ If set, look for hoist candidates in all boolean expressions """ + + UsingBlockVar = 8 + """ If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x) """ + + All = IfElseStmt | IfElseExpr | BooleanExpression | UsingBlockVar + """ Enable all hoisting of conditionals""" + + +class HoistedLetBindings(enum.Flag): + """Flags for use in HoistExpressionConfig.let_binding_types + + Each bitflag represents a type of let binding expression that should be + hoisted to the outermost loop possible. + """ + + Never = 0 + """ No hoisting of let bindings """ + + RequiredByConditional = 1 + """ Bindings that are used by a hoisted conditional """ + + LetStmt = 2 + """ Bindings occuring in LetStmt """ + + LetExpr = 4 + """ Bindings occuring in Let expressions """ + + All = RequiredByConditional | LetStmt | LetExpr + """ Enable all hoisting of let bindings """ + + +def HoistExpression(): + """Generalized verison of HoistIfThenElse. + + Hoist loop-invariant expressions to outside the eligible loops. + Searches for expressions in: + + * LetStmt bindings + * IfThenElse conditions + * Boolean operators + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + + """ + return _ffi_api.HoistExpression() # type: ignore + + def LowerCrossThreadReduction(): """Lower cross-thread reduction from thread bindings to intrinsic function calls. diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc new file mode 100644 index 000000000000..9e9d2e82cd08 --- /dev/null +++ b/src/tir/transforms/hoist_expression.cc @@ -0,0 +1,526 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file hoist_expression.cc + */ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../arith/interval_set.h" +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +enum class HoistedConditionals : int { + kNone = 0, + kIfElseStmt = (1 << 0), + kIfElseExpr = (1 << 1), + kBooleanExpression = (1 << 2), + kUsingBlockVar = (1 << 3), +}; + +enum class HoistedLetBindings : int { + kNone = 0, + kRequiredByCondition = (1 << 0), + kLetStmt = (1 << 1), + kLetExpr = (1 << 2), +}; + +struct HoistExpressionConfigNode : public tvm::AttrsNode { + int hoisted_conditionals; + int hoisted_let_bindings; + + TVM_DECLARE_ATTRS(HoistExpressionConfigNode, "tir.transform.HoistExpressionConfig") { + TVM_ATTR_FIELD(hoisted_conditionals) + .describe("Bitflags for the types of boolean expressions to hoist") + .set_default(int(HoistedConditionals::kIfElseStmt) | int(HoistedConditionals::kIfElseExpr) | + int(HoistedConditionals::kBooleanExpression)); + TVM_ATTR_FIELD(hoisted_let_bindings) + .describe("Bitflags for the types of let bindings to hoist") + .set_default(int(HoistedLetBindings::kRequiredByCondition) | + int(HoistedLetBindings::kLetStmt) | int(HoistedLetBindings::kLetExpr)); + } + + bool FlagSet(HoistedConditionals flag) const { return int(flag) & hoisted_conditionals; } + bool FlagSet(HoistedLetBindings flag) const { return int(flag) & hoisted_let_bindings; } +}; + +class HoistExpressionConfig : public Attrs { + public: + HoistExpressionConfig(int hoisted_conditionals, int hoisted_let_bindings) { + auto node = make_object(); + node->hoisted_conditionals = hoisted_conditionals; + node->hoisted_let_bindings = hoisted_let_bindings; + data_ = std::move(node); + } + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistExpressionConfig, Attrs, + HoistExpressionConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(HoistExpressionConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistExpression", HoistExpressionConfig); + +class HoistInfoCollector : public StmtExprVisitor { + public: + struct ConditionInfo { + ConditionInfo(PrimExpr condition, HoistedConditionals hoist_from, bool uses_block_var, + std::unordered_set required_let_bindings, bool generate_else_case) + : condition(condition), + hoist_from(hoist_from), + uses_block_var(uses_block_var), + required_let_bindings(required_let_bindings), + generate_else_case(generate_else_case) {} + PrimExpr condition; + HoistedConditionals hoist_from; + bool uses_block_var; + std::unordered_set required_let_bindings; + bool generate_else_case; + + bool IsEnabled(const HoistExpressionConfig& config) const { + bool valid_source = config->FlagSet(hoist_from); + + bool all_required_bindings_are_hoisted = + required_let_bindings.empty() || + config->FlagSet(HoistedLetBindings::kRequiredByCondition) || + config->FlagSet(HoistedLetBindings::kLetStmt); + + bool valid_block_var_usage = + config->FlagSet(HoistedConditionals::kUsingBlockVar) || !uses_block_var; + return valid_source && all_required_bindings_are_hoisted && valid_block_var_usage; + } + }; + + struct LetBindingInfo { + LetBindingInfo(Var var, PrimExpr value, HoistedLetBindings hoist_from) + : var(var), value(value), hoist_from(hoist_from) {} + Var var; + PrimExpr value; + HoistedLetBindings hoist_from; + + bool IsEnabled(const HoistExpressionConfig& config) const { + return config->FlagSet(hoist_from); + } + }; + + struct HoistInfo { + // The loop variable + Var loop_var; + + // The For or AttrStmt that defines the loop var. + Stmt loop_def; + + // Bindings defined in LetStmt inside the for-loop whose value + // does not depend on the loop variable. These can be hoisted + // outside this for-loop. + std::vector let_bindings; + + // Conditions evaluated inside the for-loop whose value does not + // depend on the loop variable. These can be hoisted outside this + // for loop. These may depend on the let_bindings. + std::vector conditions; + + // Only conditions that impact the entire body of the loop + // hoisted. Conditionals may not be hoisted from inside a + // sequential node to outside. + bool reached_sequential_node{false}; + + // True if the loop variable representing a block variable + // (e.g. blockIdx.x, threadIdx.x), false otherwise. + bool IsBlockVariable() const { return !loop_def.as(); } + }; + + static std::vector Collect(Stmt stmt) { + HoistInfoCollector collector; + collector(stmt); + return collector.completed_loops; + } + + private: + using Parent = StmtExprVisitor; + using Parent::VisitExpr_; + using Parent::VisitStmt_; + + HoistInfoCollector() = default; + + void AttemptHoistConditional(PrimExpr cond, HoistedConditionals hoist_from, + bool generate_else_block = true) { + if (SideEffect(cond) > CallEffectKind::kPure) { + return; + } + if (auto info = FindHoistDestination(cond)) { + if (!info->reached_sequential_node) { + // Record whether this conditional uses any block variables. + bool uses_block_var = active_block_vars.size() && UsesVar(cond, [&](const VarNode* var) { + return active_block_vars.count(var); + }); + + std::unordered_set let_bindings_used; + + for (Var var : UndefinedVars(cond)) { + auto it = let_var_to_let_vars.find(var.get()); + if (it != let_var_to_let_vars.end()) { + let_bindings_used.insert(it->first); + for (auto used : it->second) { + let_bindings_used.insert(used); + } + } + } + info->conditions.push_back(ConditionInfo(cond, hoist_from, uses_block_var, + let_bindings_used, generate_else_block)); + } + } + } + + void VisitExpr_(const AndNode* op) final { + AttemptHoistConditional(op->a, HoistedConditionals::kBooleanExpression); + AttemptHoistConditional(op->b, HoistedConditionals::kBooleanExpression); + Parent::VisitExpr_(op); + } + + void VisitExpr_(const OrNode* op) final { + AttemptHoistConditional(op->a, HoistedConditionals::kBooleanExpression); + AttemptHoistConditional(op->b, HoistedConditionals::kBooleanExpression); + Parent::VisitExpr_(op); + } + + void VisitStmt_(const ForNode* op) final { + active_loops.push_back({op->loop_var, GetRef(op)}); + active_loop_vars.insert(op->loop_var.get()); + + Parent::VisitStmt_(op); + completed_loops.push_back(active_loops.back()); + + active_loop_vars.erase(op->loop_var.get()); + active_loops.pop_back(); + } + + void VisitStmt_(const AttrStmtNode* op) final { + Var var; + if (const auto* node_iter_var = op->node.as()) { + var = node_iter_var->var; + } else if (const auto* node_var = op->node.as()) { + var = GetRef(node_var); + } else { + return Parent::VisitStmt_(op); + } + + active_block_vars.insert(var.get()); + active_loop_vars.insert(var.get()); + active_loops.push_back({var, GetRef(op)}); + + Parent::VisitStmt_(op); + + completed_loops.push_back(active_loops.back()); + active_loops.pop_back(); + + active_loop_vars.erase(var.get()); + active_block_vars.erase(var.get()); + } + + void VisitBinding(Var var, PrimExpr value, HoistedLetBindings hoist_from) { + ICHECK_EQ(let_var_to_loop_vars.count(var.get()), 0) + << "Multiple nested definitions of variable " << var; + ICHECK_EQ(let_var_to_let_vars.count(var.get()), 0) + << "Multiple nested definitions of variable " << var; + + if (auto info = FindHoistDestination(value)) { + if (!info->reached_sequential_node) { + info->let_bindings.push_back(LetBindingInfo(var, value, hoist_from)); + } + } + + // Walk through the loop binding + std::unordered_set loop_vars_used; + std::unordered_set let_bindings_used; + for (Var var : UndefinedVars(value)) { + if (active_loop_vars.count(var.get())) { + loop_vars_used.insert(var.get()); + } else { + auto it = let_var_to_loop_vars.find(var.get()); + if (it != let_var_to_loop_vars.end()) { + for (const VarNode* used : it->second) { + loop_vars_used.insert(used); + } + } + } + + auto it = let_var_to_let_vars.find(var.get()); + if (it != let_var_to_let_vars.end()) { + let_bindings_used.insert(it->first); + for (const VarNode* used : it->second) { + let_bindings_used.insert(used); + } + } + } + + let_var_to_loop_vars[var.get()] = std::move(loop_vars_used); + let_var_to_let_vars[var.get()] = std::move(let_bindings_used); + } + + void VisitStmt_(const LetStmtNode* op) final { + VisitBinding(op->var, op->value, HoistedLetBindings::kLetStmt); + + Parent::VisitStmt_(op); + + let_var_to_loop_vars.erase(op->var.get()); + let_var_to_let_vars.erase(op->var.get()); + } + + void VisitExpr_(const LetNode* op) final { + VisitBinding(op->var, op->value, HoistedLetBindings::kLetExpr); + + Parent::VisitExpr_(op); + + let_var_to_loop_vars.erase(op->var.get()); + let_var_to_let_vars.erase(op->var.get()); + } + + void VisitStmt_(const IfThenElseNode* op) final { + AttemptHoistConditional(op->condition, HoistedConditionals::kIfElseStmt, + op->else_case.defined()); + Parent::VisitStmt_(op); + } + + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::if_then_else())) { + PrimExpr cond = op->args[0]; + AttemptHoistConditional(cond, HoistedConditionals::kIfElseExpr); + } + Parent::VisitExpr_(op); + } + + void VisitStmt_(const SeqStmtNode* op) final { + if (active_loops.size()) { + active_loops.back().reached_sequential_node = true; + } + Parent::VisitStmt_(op); + } + + // Find the loop above which this expression could be hoisted. If + // nullptr, the expression cannot be hoisted. + HoistInfo* FindHoistDestination(PrimExpr expr) { + // Cannot hoist above a loop if we aren't already in a loop. + if (active_loops.empty()) { + return nullptr; + } + + for (auto it = active_loops.rbegin(); it != active_loops.rend(); it++) { + Var loop_var = it->loop_var; + bool uses_loop_var = UsesVar(expr, [&](const VarNode* var) { + if (var == loop_var.get()) { + return true; + } + + auto it = let_var_to_loop_vars.find(var); + if (it == let_var_to_loop_vars.end()) { + return false; + } + + return bool(it->second.count(loop_var.get())); + }); + + if (it->reached_sequential_node || uses_loop_var) { + if (it == active_loops.rbegin()) { + // The innermost loop iterator is used, cannot hoist. + return nullptr; + } else { + // Hoist to just below the loop iterator that is required. + it--; + return &(*it); + } + } + } + + // If no loop variables are used, can hoist above the outermost + // loop. + return &active_loops.front(); + } + + // Current thread_extent bindings of block variables. + std::unordered_set active_block_vars; + + // An ordered list of loops that are currently being visited. + std::vector active_loops; + + // Loops that have already been visited + std::vector completed_loops; + + // Map from a bound variable to the loop variables it depends on. + // Includes indirect usage. + std::unordered_map> let_var_to_loop_vars; + + // Map from a bound variable to the other let bindings it depends on. + // Includes indirect usage. + std::unordered_map> let_var_to_let_vars; + + // Lookup table for the currently active loops. + std::unordered_set active_loop_vars; +}; + +class ExpressionHoister : public arith::IRMutatorWithAnalyzer { + public: + static Stmt Hoist(Stmt stmt, HoistExpressionConfig config) { + auto loop_info = HoistInfoCollector::Collect(stmt); + + arith::Analyzer analyzer; + ExpressionHoister hoister(std::move(loop_info), config, &analyzer); + stmt = hoister(std::move(stmt)); + stmt = ConvertSSA(std::move(stmt)); + return stmt; + } + + private: + using Parent = arith::IRMutatorWithAnalyzer; + using Parent::VisitExpr_; + using Parent::VisitStmt_; + + explicit ExpressionHoister(std::vector loop_info, + HoistExpressionConfig config, arith::Analyzer* analyzer) + : Parent(analyzer), config_(config) { + for (auto& info : loop_info) { + // Mark let bindings to use if they are enabled on their own. + for (const auto& binding : info.let_bindings) { + if (binding.IsEnabled(config)) { + hoisted_let_bindings.insert(binding.var.get()); + } + } + + // Or if they are required by a conditional + if (config->FlagSet(HoistedLetBindings::kRequiredByCondition)) { + for (const auto& conditional : info.conditions) { + if (conditional.IsEnabled(config)) { + for (const auto& var : conditional.required_let_bindings) { + hoisted_let_bindings.insert(var); + } + } + } + } + + loop_info_lookup[info.loop_def.get()] = std::move(info); + } + } + + Stmt WrapHoistedStatements(Stmt stmt, const HoistInfoCollector::HoistInfo& info) { + for (auto cond_it = info.conditions.rbegin(); cond_it != info.conditions.rend(); cond_it++) { + if (cond_it->IsEnabled(config_)) { + if (cond_it->generate_else_case) { + stmt = IfThenElse(cond_it->condition, stmt, stmt); + } else { + stmt = IfThenElse(cond_it->condition, stmt); + } + } + } + for (auto let_it = info.let_bindings.rbegin(); let_it != info.let_bindings.rend(); let_it++) { + if (hoisted_let_bindings.count(let_it->var.get())) { + stmt = LetStmt(let_it->var, let_it->value, stmt); + } + } + + return stmt; + } + + Stmt VisitStmt_(const ForNode* op) final { + Stmt stmt = Parent::VisitStmt_(op); + + auto it = loop_info_lookup.find(op); + ICHECK(it != loop_info_lookup.end()) + << "Could not find pre-pass information for loop over " << op->loop_var; + return WrapHoistedStatements(stmt, it->second); + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + Stmt stmt = Parent::VisitStmt_(op); + + auto it = loop_info_lookup.find(op); + if (it == loop_info_lookup.end()) { + return stmt; + } else { + return WrapHoistedStatements(stmt, it->second); + } + } + + Stmt VisitStmt_(const LetStmtNode* op) final { + if (hoisted_let_bindings.count(op->var.get())) { + return this->VisitStmt(op->body); + } else { + return Parent::VisitStmt_(op); + } + } + + PrimExpr VisitExpr_(const LetNode* op) final { + if (hoisted_let_bindings.count(op->var.get())) { + return this->VisitExpr(op->body); + } else { + return Parent::VisitExpr_(op); + } + } + + HoistExpressionConfig config_; + + std::unordered_map loop_info_lookup; + std::unordered_set hoisted_let_bindings; +}; + +Stmt HoistExpression(Stmt stmt, HoistExpressionConfig config) { + return ExpressionHoister::Hoist(stmt, config); +} + +namespace transform { + +Pass HoistExpression() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + auto cfg = ctx->GetConfig("tir.HoistExpression"); + + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + n->body = ExpressionHoister::Hoist(std::move(n->body), cfg.value()); + return f; + }; + auto insertion_pass = CreatePrimFuncPass(pass_func, 0, "tir.InsertHoistedExpression", {}); + + return Sequential( + { + insertion_pass, + Simplify(), + RemoveNoOp(), + }, + "tir.HoistExpression"); +} + +TVM_REGISTER_GLOBAL("tir.transform.HoistExpression").set_body_typed(HoistExpression); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_hoist_expression.py b/tests/python/unittest/test_tir_transform_hoist_expression.py new file mode 100644 index 000000000000..3ea51a856d69 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_hoist_expression.py @@ -0,0 +1,428 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir +import tvm.testing + +from tvm.script import tir as T +from tvm.tir.transform import HoistExpression, HoistedConditionals, HoistedLetBindings + + +class BaseBeforeAfter: + hoisted_conditionals = tvm.testing.parameter(HoistedConditionals.All) + hoisted_let_bindings = tvm.testing.parameter(HoistedLetBindings.All) + + def test_hoist(self, hoisted_conditionals, hoisted_let_bindings): + before = self.before + before_mod = tvm.IRModule.from_expr(before) + + config = { + "tir.HoistExpression": { + "hoisted_conditionals": hoisted_conditionals.value, + "hoisted_let_bindings": hoisted_let_bindings.value, + } + } + + with tvm.transform.PassContext(config=config): + after_mod = tvm.tir.transform.HoistExpression()(before_mod) + + after = after_mod["main"] + expected = self.expected + + try: + tvm.ir.assert_structural_equal(after, expected) + except ValueError as err: + script = tvm.IRModule({"expected": expected, "after": after, "before": before}).script() + raise ValueError( + f"Function after simplification did not match expected:\n{script}" + ) from err + + +class TestHoistToTop(BaseBeforeAfter): + hoisted_conditionals = tvm.testing.parameter( + HoistedConditionals.IfElseStmt, + HoistedConditionals.All, + ) + + @T.prim_func + def before(A: T.Buffer[(16,), "float32"], n: T.int32): + for i in T.serial(16): + if n != 0: + A[i] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(16,), "float32"], n: T.int32): + if n != 0: + for i in T.serial(16): + A[i] = 0.0 + + +class TestSuppressHoistIfElse(BaseBeforeAfter): + hoisted_conditionals = tvm.testing.parameter( + HoistedConditionals.Never, + HoistedConditionals.IfElseExpr, + ) + + @T.prim_func + def before(A: T.Buffer[(16,), "float32"], n: T.int32): + for i in T.serial(16): + if n != 0: + A[i] = 0.0 + + expected = before + + +class TestHoistBlockVar(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(128, 16), "float32"], n: T.int32): + i = T.env_thread("threadIdx.x") + T.launch_thread(i, 128) + + for j in T.serial(16): + if i < 32: + A[i, j] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(128, 16), "float32"], n: T.int32): + i = T.env_thread("threadIdx.x") + T.launch_thread(i, 128) + + if i < 32: + for j in T.serial(16): + A[i, j] = 0.0 + + +class TestSuppressHoistBlockVar(BaseBeforeAfter): + hoisted_conditionals = tvm.testing.parameter( + HoistedConditionals.All & ~HoistedConditionals.UsingBlockVar + ) + + @T.prim_func + def before(A: T.Buffer[(128, 16), "float32"], n: T.int32): + thread_x = T.env_thread("threadIdx.x") + T.launch_thread(thread_x, 128) + + for i in T.thread_binding(0, 128, thread="threadIdx.x"): + if i < 32: + for j in T.serial(16): + A[i, j] = 0.0 + + expected = before + + +class TestHoistToMiddle(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + for j in T.serial(4): + if i < 3: + A[i, j] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + if i < 3: + for j in T.serial(4): + A[i, j] = 0.0 + + +class TestHoistWithLet(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + for j in T.serial(4): + condition = i < 3 + if condition: + A[i, j] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + condition = i < 3 + if condition: + for j in T.serial(4): + A[i, j] = 0.0 + + +class TestHoistDisableLet(BaseBeforeAfter): + """As TestHoistWithLet, but forbid hoisting of LetStmt + + Because the condition depends on the let binding, it should no + longer be hoisted. + """ + + hoisted_let_bindings = tvm.testing.parameter(HoistedLetBindings.Never) + + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + for j in T.serial(4): + condition = i < 3 + if condition: + A[i, j] = 0.0 + + expected = before + + +class TestHoistIfElse(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + for j in T.serial(4): + if i < 3: + A[i, j] = 0.0 + else: + A[i, j] = 1.0 + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + if i < 3: + for j in T.serial(4): + A[i, j] = 0.0 + else: + for j in T.serial(4): + A[i, j] = 1.0 + + +class TestHoistSequentialAssign(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + for j in T.serial(4): + if i < 3: + A[i, j] = 0.0 + B[i, j] = 0.0 + else: + A[i, j] = 1.0 + B[i, j] = 1.0 + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + if i < 3: + for j in T.serial(4): + A[i, j] = 0.0 + B[i, j] = 0.0 + else: + for j in T.serial(4): + A[i, j] = 1.0 + B[i, j] = 1.0 + + +class TestHoistMultiIf(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + for j in T.serial(4): + for k in T.serial(4): + if j < 3: + if i < 2: + A[i, j] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + if i < 2: + for j in T.serial(4): + if j < 3: + for k in T.serial(4): + A[i, j] = 0.0 + + +class TestHoistComplexConditional(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i, j, k in T.grid(4, 4, 4): + if j < 3 and i < 2: + A[i, j] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + if i < 2: + for j in T.serial(4): + if j < 3: + for k in T.serial(4): + A[i, j] = 0.0 + + +class TestSuppressSplittingConditional(BaseBeforeAfter): + hoisted_conditionals = tvm.testing.parameter( + HoistedConditionals.All & ~HoistedConditionals.BooleanExpression + ) + + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i, j, k in T.grid(4, 4, 4): + if j < 3 and i < 2: + A[i, j] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"]): + for i, j in T.grid(4, 4): + if j < 3 and i < 2: + for k in T.serial(4): + A[i, j] = 0.0 + + +class TestHoistMultiIfElse(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + for j in T.serial(4): + for k in T.serial(4): + if j < 3: + if i < 2: + A[i, j] = 0.0 + else: + A[i, j] = 1.0 + else: + if i < 2: + A[i, j] = 2.0 + else: + A[i, j] = 3.0 + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + if i < 2: + for j in T.serial(4): + if j < 3: + for k in T.serial(4): + A[i, j] = 0.0 + else: + for k in T.serial(4): + A[i, j] = 2.0 + else: + for j in T.serial(4): + if j < 3: + for k in T.serial(4): + A[i, j] = 1.0 + else: + for k in T.serial(4): + A[i, j] = 3.0 + + +class TestHoistMultiIfElseDifferentBranches(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + for j in T.serial(4): + for k in T.serial(4): + if j < 3: + if i < 2: + A[i, j] = 0.0 + else: + A[i, j] = 1.0 + else: + if i < 1: + A[i, j] = 2.0 + else: + A[i, j] = 3.0 + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + if i < 2: + if i < 1: + for j in T.serial(4): + if j < 3: + for k in T.serial(4): + A[i, j] = 0.0 + else: + for k in T.serial(4): + A[i, j] = 2.0 + else: + for j in T.serial(4): + if j < 3: + for k in T.serial(4): + A[i, j] = 0.0 + else: + for k in T.serial(4): + A[i, j] = 3.0 + else: + for j in T.serial(4): + if j < 3: + for k in T.serial(4): + A[i, j] = 1.0 + else: + for k in T.serial(4): + A[i, j] = 3.0 + + +class TestHoistIfElseExpr(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i, j in T.grid(4, 4): + A[i, j] = T.if_then_else(i < 2, 1.0, 2.0, dtype="float32") + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + if i < 2: + for j in T.serial(4): + A[i, j] = 1.0 + else: + for j in T.serial(4): + A[i, j] = 2.0 + + +class TestSuppressHoistIfElseExpr(TestHoistIfElseExpr): + hoisted_conditionals = tvm.testing.parameter( + HoistedConditionals.All & ~HoistedConditionals.IfElseExpr + ) + + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i, j in T.grid(4, 4): + A[i, j] = T.if_then_else(i < 2, 1.0, 2.0, dtype="float32") + + expected = before + + +class TestHoistLetExpr(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i, j in T.grid(4, 4): + x = T.var("float32") + A[i, j] = tir.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32")) + + @T.prim_func + def expected(A: T.Buffer[(4, 4), "float32"]): + for i in T.serial(4): + x = T.cast(i + 1, "float32") + for j in T.serial(4): + A[i, j] = 5.0 * x + T.cast(j, "float32") + + +class TestSuppressHoistLetExpr(BaseBeforeAfter): + hoisted_let_bindings = tvm.testing.parameter( + HoistedLetBindings.All & ~HoistedLetBindings.LetExpr + ) + + @T.prim_func + def before(A: T.Buffer[(4, 4), "float32"]): + for i, j in T.grid(4, 4): + x = T.var("float32") + A[i, j] = tir.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32")) + + expected = before + + +if __name__ == "__main__": + tvm.testing.main() From fbe750a64c261fecec29b021842acf21236fa760 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Jun 2022 10:02:51 -0500 Subject: [PATCH 05/10] [TIR] Expressed HoistIfThenElse as special case of HoistExpression --- src/tir/transforms/hoist_expression.cc | 69 +++ src/tir/transforms/hoist_if_then_else.cc | 438 ------------------ .../unittest/test_tir_transform_hoist_if.py | 61 ++- 3 files changed, 107 insertions(+), 461 deletions(-) delete mode 100644 src/tir/transforms/hoist_if_then_else.cc diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index 9e9d2e82cd08..a508822ec7c3 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -89,6 +89,27 @@ class HoistExpressionConfig : public Attrs { TVM_REGISTER_NODE_TYPE(HoistExpressionConfigNode); TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistExpression", HoistExpressionConfig); +struct HoistIfThenElseConfigNode : public tvm::AttrsNode { + // Would like to replace the typo here from "hosting" to "hoisting", + // but that may impact user configurations. + bool support_block_scope_hosting; + + TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig") { + TVM_ATTR_FIELD(support_block_scope_hosting) + .describe("Hoist if cond with block scope variables") + .set_default(false); + } +}; + +class HoistIfThenElseConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs, + HoistIfThenElseConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig); + class HoistInfoCollector : public StmtExprVisitor { public: struct ConditionInfo { @@ -520,6 +541,54 @@ Pass HoistExpression() { TVM_REGISTER_GLOBAL("tir.transform.HoistExpression").set_body_typed(HoistExpression); +Pass HoistIfThenElse() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + auto cfg = ctx->GetConfig("tir.HoistIfThenElse"); + + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + int block_var = + int(cfg.value()->support_block_scope_hosting ? HoistedConditionals::kUsingBlockVar + : HoistedConditionals::kNone); + HoistExpressionConfig config(block_var | int(HoistedConditionals::kIfElseStmt), + int(HoistedLetBindings::kNone)); + n->body = ExpressionHoister::Hoist(std::move(n->body), config); + return f; + }; + auto insertion_pass = CreatePrimFuncPass(pass_func, 0, "tir.InsertHoistIfThenElse", {}); + return Sequential( + { + insertion_pass, + Simplify(), + RemoveNoOp(), + }, + "tir.HoistIfThenElse"); +} + +TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); + +Pass HoistIfThenElseBasic() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + HoistExpressionConfig config(static_cast(HoistedConditionals::kIfElseStmt), + static_cast(HoistedLetBindings::kNone)); + n->body = ExpressionHoister::Hoist(std::move(n->body), config); + return f; + }; + auto insertion_pass = CreatePrimFuncPass(pass_func, 0, "tir.InsertHoistIfThenElseBasic", {}); + return Sequential( + { + insertion_pass, + Simplify(), + RemoveNoOp(), + }, + "tir.HoistIfThenElseBasic"); +} + +TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic); + } // namespace transform } // namespace tir diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc deleted file mode 100644 index 4a11a7e90e30..000000000000 --- a/src/tir/transforms/hoist_if_then_else.cc +++ /dev/null @@ -1,438 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file hoist_if_then_else.cc - */ -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "../../arith/interval_set.h" -#include "../../runtime/thread_storage_scope.h" -#include "ir_utils.h" - -namespace tvm { -namespace tir { - -struct HoistIfThenElseConfigNode : public tvm::AttrsNode { - bool support_block_scope_hosting; - - TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig") { - TVM_ATTR_FIELD(support_block_scope_hosting) - .describe("Hoist if cond with block scope variables") - .set_default(false); - } -}; - -class HoistIfThenElseConfig : public Attrs { - public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs, - HoistIfThenElseConfigNode); -}; - -TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig); - -using VarForMap = std::unordered_map; -using HoistForIfTuple = std::tuple; - -/* - * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant. - * For example, given the following block: - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * for (k = 0; k < 5; k++) - * if (likely(i*2 < 4)) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt. - * Then we hoist IfThenElse stmt by one For stmt each step: - * - * Step 1: - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * if (likely(i*2 < 4)) - * for (k = 0; k < 5; k++) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * Step 2: - * for (i = 0; i < 3; i++) - * if (likely(i*2 < 4)) - * for (j = 0; j < 4; j++) - * for (k = 0; k < 5; k++) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * In this pass, we only continue detecting possible hoisting chance when visiting For, - * IfThenElse or AttrStmt Node. For example, for the following block: - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * A[i + j] = A[i + j] - 1 - * for (k = 0; k < 5; k++) - * if (likely(i*2 < 4)) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * Only the For with k variable will be considered and the resulting stmt would be: - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * A[i + j] = A[i + j] - 1 - * if (likely(i*2 < 4)) - * for (k = 0; k < 5; k++) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following - * block won't be optimized: - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * for (k = 0; k < 5; k++) - * if (likely(i*2 < 4)) - * A[3*i+2j+k] = B[7*i+3j+k] - * if (likely(j > 2)) - * A[i+j+k] = B[i+j+k] - * - * - * This pass do hoisting for Block scope variables also. - * As below: - * Attr(IterVar: threadIdx.x) - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * for (k = 0; k < 5; k++) - * if (likely(threadIdx.x < 3)) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * Will be transformed to as below: - * Attr(IterVar: threadIdx.x) - * if (likely(threadIdx.x < 3)) - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * for (k = 0; k < 5; k++) - * A[3*i+2j+k] = B[7*i+3j+k] - * - */ - -// Select potential candidate IRs that can be hoisted. -class HoistCandidateSelector final : public StmtExprVisitor { - public: - explicit HoistCandidateSelector(bool support_block_scope_hosting) - : support_block_scope_hosting_(support_block_scope_hosting) { - InitRecorder(); - } - HoistCandidateSelector() { InitRecorder(); } - - void VisitStmt_(const ForNode* op) final { - // If already recording complete, - // then stop tracing - if (RecordingComplete()) { - return; - } - - // Check if it is first for loop, then start the recorder - StartOrAddRecord(GetRef(op)); - StmtExprVisitor::VisitStmt_(op); - RemoveRecord(GetRef(op)); - } - - void VisitStmt_(const SeqStmtNode* op) final { - // If SeqStmt is encountered in the middle of recording - // then need to purge all, as it can not be hoisted - if (IsRecordingOn()) { - ResetRecorderInternal(); - } - StmtExprVisitor::VisitStmt_(op); - } - - void VisitStmt_(const AttrStmtNode* op) final { - // Maintain list of all vars in AttrStmt - // To stop hoisting if any of the block variables are used. - // - // In case we want to use hoisting in between certain passes - // which have interdependencies of the positioning of if nodes with scope var - // it is better to disable this section - if (support_block_scope_hosting_) { - if (IsRecordingOn()) { - StartOrAddRecord(GetRef(op)); - StmtExprVisitor::VisitStmt_(op); - RemoveRecord(GetRef(op)); - return; - } else { - return StmtExprVisitor::VisitStmt_(op); - } - } - UpdateAttrVarList(op); - StmtExprVisitor::VisitStmt_(op); - RemoveAttrVarList(op); - } - - void VisitStmt_(const IfThenElseNode* op) final { - if (!IsRecordingOn()) { - StmtExprVisitor::VisitStmt_(op); - return; - } - - is_if_cond_ = true; - StmtExprVisitor::VisitExpr(op->condition); - is_if_cond_ = false; - - if (CheckValidIf()) { - // Check corresponding for loop - int match_for_loop_pos = -1; - for (auto var : if_var_list_) { - for (int i = 0; i < static_cast(ordered_list_.size()); ++i) { - if ((ordered_list_[i] == var_for_map_[var]) || (ordered_list_[i] == var)) { - if (match_for_loop_pos < i) { - match_for_loop_pos = i; - } - } - } - } - // If none of the for loop has the matching loop variable as if condition, - // then the if node need to be hoisted on top of all, provided no parent loop exists. - int target_for_pos = GetNextLoopPos(match_for_loop_pos); - - // Check if valid position - if (target_for_pos >= 0) { - StopAndAddRecord(static_cast(ordered_list_[target_for_pos]), op); - if_var_list_.clear(); - return; - } - } - - if_var_list_.clear(); - StmtExprVisitor::VisitStmt_(op); - StopRecording(); - } - - void VisitExpr_(const VarNode* op) final { - if (is_if_cond_) { - if_var_list_.emplace_back(op); - } - } - - HoistForIfTuple hoist_for_if_recorder; - - void ResetRecorder() { - ResetRecorderInternal(); - - // Reset Block scope vars also here - attr_var_list_.clear(); - } - - bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); } - - const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); } - - const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); } - - private: - void ResetRecorderInternal() { - if (is_recorder_on_) { - ICHECK_GT(ordered_list_.size(), 0); - is_recorder_on_ = false; - } - ordered_list_.clear(); - var_for_map_.clear(); - hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); - } - bool CheckValidIf() { - // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop - // hoisting - return ((!if_var_list_.empty()) && (!CheckAttrVar())); - } - - int GetNextLoopPos(int cur_pos) { - for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) { - if (ordered_list_[i]->IsInstance()) { - return i; - } - } - return -1; - } - - void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); } - - void StopRecording() { is_recorder_on_ = false; } - - bool IsRecordingOn() { return is_recorder_on_; } - - void StartOrAddRecord(const ObjectRef& op) { - is_recorder_on_ = true; - if (const auto* node = op.as()) { - if (!var_for_map_.count(node->loop_var.get())) - var_for_map_.insert({node->loop_var.get(), node}); - ordered_list_.emplace_back(op.get()); - } else if (const auto* node = op.as()) { - if (const auto* iv = node->node.as()) { - ordered_list_.emplace_back(iv->var.get()); - } else if (const auto* iv = node->node.as()) { - ordered_list_.emplace_back(iv); - } - } - } - - void RemoveRecord(const ObjectRef& op) { - StopRecording(); - if (const auto* node = op.as()) var_for_map_.erase(node->loop_var.get()); - if (ordered_list_.size() > 0) ordered_list_.pop_back(); - } - - void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) { - hoist_for_if_recorder = std::make_tuple(true, for_node, if_node); - StopRecording(); - } - - void UpdateAttrVarList(const AttrStmtNode* op) { - if (const auto* iv = op->node.as()) { - attr_var_list_.insert(iv->var.get()); - } else if (const auto* iv = op->node.as()) { - attr_var_list_.insert(iv); - } - } - - void RemoveAttrVarList(const AttrStmtNode* op) { - if (const auto* iv = op->node.as()) { - attr_var_list_.erase(iv->var.get()); - } else if (const auto* iv = op->node.as()) { - attr_var_list_.erase(iv); - } - } - - bool CheckAttrVar() { - for (auto var : if_var_list_) { - if (attr_var_list_.count(var)) { - return true; - } - } - return false; - } - - // Ordered List maintains all ForNodes & AttrStmtNodes encountered in sequence - std::vector ordered_list_; - std::vector if_var_list_; - std::unordered_set attr_var_list_; - VarForMap var_for_map_; - - bool is_if_cond_{false}; - bool is_recorder_on_{false}; - bool support_block_scope_hosting_{false}; -}; - -class IfThenElseHoister : public StmtMutator { - public: - IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {} - explicit IfThenElseHoister(bool support_block_scope_hosting) - : hoist_selector_(HoistCandidateSelector(support_block_scope_hosting)) {} - - Stmt VisitAndMutate(Stmt stmt) { - hoist_selector_(stmt); - Stmt stmt_copy = std::move(stmt); - - while (hoist_selector_.RecordingComplete()) { - target_for_ = hoist_selector_.GetTargetForNode(); - target_if_ = hoist_selector_.GetTargetIfNode(); - - stmt_copy = operator()(stmt_copy); - - hoist_selector_.ResetRecorder(); - hoist_selector_(stmt_copy); - } - - // Support SSA Form - stmt_copy = ConvertSSA(stmt_copy); - return stmt_copy; - } - - Stmt VisitStmt_(const ForNode* op) final { - if ((!is_updating_) && (target_for_ == op)) { - is_updating_ = true; - is_then_case_ = true; - Stmt then_case = StmtMutator::VisitStmt_(op); - is_then_case_ = false; - Stmt else_case = Stmt(); - if (target_if_->else_case.defined()) { - else_case = StmtMutator::VisitStmt_(op); - } - is_updating_ = false; - return IfThenElse(target_if_->condition, then_case, else_case); - } - return StmtMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const IfThenElseNode* op) final { - if (is_updating_ && (target_if_ == op)) { - if (is_then_case_) { - return StmtMutator::VisitStmt(op->then_case); - } else if (op->else_case.defined()) { - return StmtMutator::VisitStmt(op->else_case); - } - } - return StmtMutator::VisitStmt_(op); - } - - private: - bool is_updating_{false}; - bool is_then_case_{false}; - HoistCandidateSelector hoist_selector_; - const ForNode* target_for_; - const IfThenElseNode* target_if_; -}; - -Stmt HoistIfThenElse(Stmt stmt, bool support_block_scope_hosting) { - return IfThenElseHoister(support_block_scope_hosting).VisitAndMutate(stmt); -} -Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoister().VisitAndMutate(stmt); } - -namespace transform { - -Pass HoistIfThenElse() { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - auto cfg = ctx->GetConfig("tir.HoistIfThenElse"); - - if (!cfg.defined()) { - cfg = AttrsWithDefaultValues(); - } - n->body = HoistIfThenElse(std::move(n->body), cfg.value()->support_block_scope_hosting); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {}); -} - -Pass HoistIfThenElseBasic() { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - n->body = HoistIfThenElse(std::move(n->body)); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElseBasic", {}); -} - -TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); - -TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic); - -} // namespace transform - -} // namespace tir -} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index b111e2be75c7..0270500828b8 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -61,6 +61,10 @@ def _visit(op): var_list.clear() +def _opaque_eval(var): + return tvm.tir.Evaluate(tvm.tir.call_extern("int32", "dummy", var)) + + def test_hoist_top_for(): ib = tvm.tir.ir_builder.create() l = te.var("l") @@ -72,9 +76,9 @@ def test_hoist_top_for(): with ib.for_range(0, m, "j") as j: with ib.for_range(0, n, "k") as k: with ib.if_scope(ib.likely(i < 2)): - ib.emit(tvm.tir.Evaluate(m)) + ib.emit(_opaque_eval(m)) with ib.else_scope(): - ib.emit(tvm.tir.Evaluate(n)) + ib.emit(_opaque_eval(n)) stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) @@ -99,13 +103,14 @@ def test_hoist_multi_var_if(): with ib.for_range(0, m, "j") as j: with ib.for_range(0, n, "k") as k: with ib.if_scope(ib.likely(i + j < 2)): - ib.emit(tvm.tir.Evaluate(m)) + ib.emit(_opaque_eval(m)) with ib.else_scope(): - ib.emit(tvm.tir.Evaluate(n)) + ib.emit(_opaque_eval(n)) stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) - new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + new_mod = tvm.tir.transform.HoistIfThenElse()(mod) + new_stmt = new_mod["main"].body expected_struct = { ("tir.For", "k"): (None,), ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")), @@ -127,9 +132,9 @@ def test_hoist_no_match_for(): data[i * 3 + j] = data[i * 3 + j] + 0.5 with ib.for_range(0, n, "k") as k: with ib.if_scope(ib.likely(i < 2)): - ib.emit(tvm.tir.Evaluate(m)) + ib.emit(_opaque_eval(m)) with ib.else_scope(): - ib.emit(tvm.tir.Evaluate(n)) + ib.emit(_opaque_eval(n)) stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) @@ -153,7 +158,7 @@ def test_no_else(): with ib.for_range(0, m, "j") as j: with ib.for_range(0, n, "k") as k: with ib.if_scope(ib.likely(i < 2)): - ib.emit(tvm.tir.Evaluate(m)) + ib.emit(_opaque_eval(m)) stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) @@ -277,13 +282,14 @@ def test_multi_if(): with ib.for_range(0, 10, "i") as i: with ib.for_range(0, 10, "j") as j: with ib.for_range(0, 10, "k") as k: - with ib.if_scope(i >= 3): - with ib.if_scope(j >= 3): + with ib.if_scope(3 <= i): + with ib.if_scope(3 <= j): data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5 stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) - new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + new_mod = tvm.tir.transform.HoistIfThenElse()(mod) + new_stmt = new_mod["main"].body expected_struct = { ("tir.For", "k"): (None,), ("tir.IfThenElse", ("j",)): (("tir.For", "k"), None), @@ -302,7 +308,7 @@ def test_no_hoisting_1(): with ib.for_range(0, 10, "i") as i: with ib.for_range(0, 10, "j") as j: with ib.for_range(0, 10, "k") as k: - with ib.if_scope(k >= 3): + with ib.if_scope(k <= 3): data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5 stmt = ib.get() @@ -326,7 +332,7 @@ def test_no_hoisting_2(): with ib.for_range(0, 10, "i") as i: with ib.for_range(0, 10, "j") as j: with ib.for_range(0, 10, "k") as k: - with ib.if_scope(i >= 3): + with ib.if_scope(i <= 3): data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.3 data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5 @@ -342,6 +348,7 @@ def test_no_hoisting_2(): tvm.ir.assert_structural_equal(new_stmt, stmt) +@pytest.mark.xfail(reason="Inconsistent thread_extent", strict=True) def test_no_hoisting_3(): ib = tvm.tir.ir_builder.create() dshape = (32, 64) @@ -410,6 +417,7 @@ def test_no_hoisting_4(): tvm.ir.assert_structural_equal(new_stmt, stmt) +@pytest.mark.xfail(reason="Inconsistent thread_extent", strict=True) def test_no_hoisting_5(): ib = tvm.tir.ir_builder.create() dshape = (32, 64) @@ -522,15 +530,17 @@ def test_hoisting_block_scope_1(): s[B.op].bind(xi, te.thread_axis("threadIdx.y")) s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x")) s[BF].compute_at(s[B], s[B].op.reduce_axis[0]) - func = tvm.driver.build_module.schedule_to_module(s, [A, B], "main", None)["main"] - stmt = func.body - new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + mod = tvm.driver.build_module.schedule_to_module(s, [A, B], "main", None) + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.RemoveNoOp()(mod) + stmt = mod["main"].body + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) with tvm.transform.PassContext( config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}} ): - new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body assert not tvm.ir.structural_equal(new_stmt, stmt) @@ -558,6 +568,10 @@ def test_hoisting_block_scope_2(): stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.RemoveNoOp()(mod) + stmt = mod["main"].body + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) @@ -565,10 +579,10 @@ def test_hoisting_block_scope_2(): config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}} ): new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body - # tvm.ir.assert_structural_equal(new_stmt, stmt) assert not tvm.ir.structural_equal(new_stmt, stmt) +@pytest.mark.xfail(reason="Inconsistent thread_extent", strict=True) def test_hoisting_block_scope_3(): ib = tvm.tir.ir_builder.create() dshape = (32, 64) @@ -601,7 +615,6 @@ def test_hoisting_block_scope_3(): config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}} ): new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body - # tvm.ir.assert_structural_equal(new_stmt, stmt) assert not tvm.ir.structural_equal(new_stmt, stmt) @@ -622,15 +635,17 @@ def test_hoisting_block_scope_4(): s[C].pragma(xo2, "parallel_stride_pattern") s[C].pragma(xo2, "parallel_barrier_when_finish") s[C].vectorize(xi) - func = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None)["main"] - stmt = func.body - new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + mod = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None) + mod = tvm.tir.transform.Simplify()(mod) + + stmt = mod["main"].body + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) with tvm.transform.PassContext( config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}} ): - new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body assert not tvm.ir.structural_equal(new_stmt, stmt) From fa118814eeb6752bef0ecad49afd4a44785fdc21 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Jun 2022 10:47:49 -0500 Subject: [PATCH 06/10] Lint fixes --- src/tir/transforms/hoist_expression.cc | 32 +++++++++++++++----------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index a508822ec7c3..76bf8ebe72c4 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -62,16 +62,22 @@ struct HoistExpressionConfigNode : public tvm::AttrsNode(HoistedConditionals::kIfElseStmt) | + static_cast(HoistedConditionals::kIfElseExpr) | + static_cast(HoistedConditionals::kBooleanExpression)); TVM_ATTR_FIELD(hoisted_let_bindings) .describe("Bitflags for the types of let bindings to hoist") - .set_default(int(HoistedLetBindings::kRequiredByCondition) | - int(HoistedLetBindings::kLetStmt) | int(HoistedLetBindings::kLetExpr)); + .set_default(static_cast(HoistedLetBindings::kRequiredByCondition) | + static_cast(HoistedLetBindings::kLetStmt) | + static_cast(HoistedLetBindings::kLetExpr)); } - bool FlagSet(HoistedConditionals flag) const { return int(flag) & hoisted_conditionals; } - bool FlagSet(HoistedLetBindings flag) const { return int(flag) & hoisted_let_bindings; } + bool FlagSet(HoistedConditionals flag) const { + return static_cast(flag) & hoisted_conditionals; + } + bool FlagSet(HoistedLetBindings flag) const { + return static_cast(flag) & hoisted_let_bindings; + } }; class HoistExpressionConfig : public Attrs { @@ -356,7 +362,7 @@ class HoistInfoCollector : public StmtExprVisitor { for (auto it = active_loops.rbegin(); it != active_loops.rend(); it++) { Var loop_var = it->loop_var; - bool uses_loop_var = UsesVar(expr, [&](const VarNode* var) { + bool uses_loop_var = UsesVar(expr, [&](const VarNode* var) -> bool { if (var == loop_var.get()) { return true; } @@ -366,7 +372,7 @@ class HoistInfoCollector : public StmtExprVisitor { return false; } - return bool(it->second.count(loop_var.get())); + return it->second.count(loop_var.get()); }); if (it->reached_sequential_node || uses_loop_var) { @@ -549,11 +555,11 @@ Pass HoistIfThenElse() { if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } - int block_var = - int(cfg.value()->support_block_scope_hosting ? HoistedConditionals::kUsingBlockVar - : HoistedConditionals::kNone); - HoistExpressionConfig config(block_var | int(HoistedConditionals::kIfElseStmt), - int(HoistedLetBindings::kNone)); + int block_var = static_cast(cfg.value()->support_block_scope_hosting + ? HoistedConditionals::kUsingBlockVar + : HoistedConditionals::kNone); + HoistExpressionConfig config(block_var | static_cast(HoistedConditionals::kIfElseStmt), + static_cast(HoistedLetBindings::kNone)); n->body = ExpressionHoister::Hoist(std::move(n->body), config); return f; }; From 9e70691ebfdbed638467005f3cf28b84e104eac0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Jun 2022 16:28:03 -0500 Subject: [PATCH 07/10] Fixed breakage in tvmc unit test that relied on pass type --- tests/python/driver/tvmc/test_pass_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/driver/tvmc/test_pass_config.py b/tests/python/driver/tvmc/test_pass_config.py index f928c8a31293..034f761f1d6b 100644 --- a/tests/python/driver/tvmc/test_pass_config.py +++ b/tests/python/driver/tvmc/test_pass_config.py @@ -23,6 +23,7 @@ from tvm.driver.tvmc import TVMCException from tvm.driver.tvmc.pass_config import parse_configs from tvm.tir.transform import PrimFuncPass +from tvm.ir.transform import Sequential def test_config_invalid_format(): @@ -89,7 +90,8 @@ def test_add_lower_pass_multi_built_in_pass(): assert isinstance(configs["tir.add_lower_pass"][0][1], PrimFuncPass) # opt_level: 1, pass: tir.transform.HoistIfThenElse assert configs["tir.add_lower_pass"][1][0] == 1 - assert isinstance(configs["tir.add_lower_pass"][1][1], PrimFuncPass) + assert isinstance(configs["tir.add_lower_pass"][1][1], Sequential) + assert configs["tir.add_lower_pass"][1][1].pass_info.name == "tir.HoistIfThenElse" # opt_level: 2, pass: tir.transform.LoopPartition assert configs["tir.add_lower_pass"][2][0] == 2 assert isinstance(configs["tir.add_lower_pass"][2][1], PrimFuncPass) From 12f4926b8a04e7ef5db2cba1b234323f335d6cf4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Jun 2022 08:27:23 -0500 Subject: [PATCH 08/10] More accurate handling of kUsingBlockVar Didn't correctly reproduce previous behavior. In addition to preventing hoisting of expressions that use a block variable (e.g. threadIdx.x), should also prevent hoisting of expressions across a "thread_extent" AttrStmt. --- src/tir/transforms/hoist_expression.cc | 19 +++++--- .../test_tir_transform_hoist_expression.py | 48 +++++++++++++++++++ 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index 76bf8ebe72c4..2ba800d44345 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -185,8 +185,8 @@ class HoistInfoCollector : public StmtExprVisitor { bool IsBlockVariable() const { return !loop_def.as(); } }; - static std::vector Collect(Stmt stmt) { - HoistInfoCollector collector; + static std::vector Collect(Stmt stmt, HoistExpressionConfig config) { + HoistInfoCollector collector(config); collector(stmt); return collector.completed_loops; } @@ -196,7 +196,7 @@ class HoistInfoCollector : public StmtExprVisitor { using Parent::VisitExpr_; using Parent::VisitStmt_; - HoistInfoCollector() = default; + HoistInfoCollector(HoistExpressionConfig config) : config(config) {} void AttemptHoistConditional(PrimExpr cond, HoistedConditionals hoist_from, bool generate_else_block = true) { @@ -375,9 +375,12 @@ class HoistInfoCollector : public StmtExprVisitor { return it->second.count(loop_var.get()); }); - if (it->reached_sequential_node || uses_loop_var) { + bool is_disabled_hoist_across_block_var = + !config->FlagSet(HoistedConditionals::kUsingBlockVar) && it->IsBlockVariable(); + + if (it->reached_sequential_node || uses_loop_var || is_disabled_hoist_across_block_var) { if (it == active_loops.rbegin()) { - // The innermost loop iterator is used, cannot hoist. + // Cannot hoist beyond the innermost loop iterator. return nullptr; } else { // Hoist to just below the loop iterator that is required. @@ -392,6 +395,10 @@ class HoistInfoCollector : public StmtExprVisitor { return &active_loops.front(); } + // The user-provided config describing which expressions should be + // hoisted. + HoistExpressionConfig config; + // Current thread_extent bindings of block variables. std::unordered_set active_block_vars; @@ -416,7 +423,7 @@ class HoistInfoCollector : public StmtExprVisitor { class ExpressionHoister : public arith::IRMutatorWithAnalyzer { public: static Stmt Hoist(Stmt stmt, HoistExpressionConfig config) { - auto loop_info = HoistInfoCollector::Collect(stmt); + auto loop_info = HoistInfoCollector::Collect(stmt, config); arith::Analyzer analyzer; ExpressionHoister hoister(std::move(loop_info), config, &analyzer); diff --git a/tests/python/unittest/test_tir_transform_hoist_expression.py b/tests/python/unittest/test_tir_transform_hoist_expression.py index 3ea51a856d69..e52eb4c5063c 100644 --- a/tests/python/unittest/test_tir_transform_hoist_expression.py +++ b/tests/python/unittest/test_tir_transform_hoist_expression.py @@ -124,6 +124,54 @@ def before(A: T.Buffer[(128, 16), "float32"], n: T.int32): expected = before +class TestHoistAcrossBlockVar(BaseBeforeAfter): + @T.prim_func + def before(A: T.Buffer[(128, 16), "float32"], n: T.int32): + thread_x = T.env_thread("threadIdx.x") + T.launch_thread(thread_x, 128) + + for i in T.thread_binding(0, 128, thread="threadIdx.x"): + if n == 0: + for j in T.serial(16): + A[i, j] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(128, 16), "float32"], n: T.int32): + thread_x = T.env_thread("threadIdx.x") + + if n == 0: + T.launch_thread(thread_x, 128) + for i in T.thread_binding(0, 128, thread="threadIdx.x"): + for j in T.serial(16): + A[i, j] = 0.0 + + +class TestSuppressHoistAcrossBlockVar(BaseBeforeAfter): + hoisted_conditionals = tvm.testing.parameter( + HoistedConditionals.All & ~HoistedConditionals.UsingBlockVar + ) + + @T.prim_func + def before(A: T.Buffer[(128, 16), "float32"], n: T.int32): + thread_x = T.env_thread("threadIdx.x") + T.launch_thread(thread_x, 128) + + for i in T.thread_binding(0, 128, thread="threadIdx.x"): + for j in T.serial(16): + if n == 0: + A[i, j] = 0.0 + + @T.prim_func + def expected(A: T.Buffer[(128, 16), "float32"], n: T.int32): + thread_x = T.env_thread("threadIdx.x") + + T.launch_thread(thread_x, 128) + if n == 0: + for i in T.thread_binding(0, 128, thread="threadIdx.x"): + for j in T.serial(16): + A[i, j] = 0.0 + + class TestHoistToMiddle(BaseBeforeAfter): @T.prim_func def before(A: T.Buffer[(4, 4), "float32"]): From e54ceee14d2e8e448c2d38e2b90f9aa14896916f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 23 Jun 2022 08:02:27 -0500 Subject: [PATCH 09/10] Updated comment for HoistExpression pass --- include/tvm/tir/transform.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 68dda0c8374b..2f3ea75e6c40 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -365,9 +365,14 @@ TVM_DLL Pass PointerValueTypeRewrite(); TVM_DLL Pass HoistIfThenElse(); /*! - * \brief Hoist loop-invariant IfThenElse nodes to + * \brief Hoist loop-invariant expressions nodes to * outside the elligible loops. * + * Can hoist conditionals used in IfThenElse statements and + * expressions, bindings of variables in Let statements and + * expressions, or boolean expressions, configurable to enable/disable + * each hoistable type. + * * \return The pass. */ TVM_DLL Pass HoistExpression(); From 22d7a3d7db3813144ed6d4cfe3d2a6eedf3f9a94 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 23 Jun 2022 10:32:07 -0500 Subject: [PATCH 10/10] Fix linting error --- src/tir/transforms/hoist_expression.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index 2ba800d44345..ffc58f3a42b7 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -196,7 +196,7 @@ class HoistInfoCollector : public StmtExprVisitor { using Parent::VisitExpr_; using Parent::VisitStmt_; - HoistInfoCollector(HoistExpressionConfig config) : config(config) {} + explicit HoistInfoCollector(HoistExpressionConfig config) : config(config) {} void AttemptHoistConditional(PrimExpr cond, HoistedConditionals hoist_from, bool generate_else_block = true) {