diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 82533a2f9f5a..35752738c33d 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -242,6 +242,17 @@ def UnrollLoop():
     return _ffi_api.UnrollLoop()  # type: ignore
 
 
+def ReduceBranchingThroughOvercompute():
+    """Reduce branching by introducing overcompute
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.ReduceBranchingThroughOvercompute()  # type: ignore
+
+
 def RemoveNoOp():
     """Remove No Op from the Stmt.
 
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index c9d92f992564..f1838f5a9099 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1388,8 +1388,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
   EQ ret = Downcast<EQ>(IRMutatorWithAnalyzer::VisitExpr_(op));
   op = ret.get();
 
-  if (auto const_res = TryConstFold<EQ>(op->a, op->b)) return const_res.value();
-  if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
+  if (auto const_res = TryConstFold<EQ>(op->a, op->b)) {
+    return const_res.value();
+  }
+  if (auto match = TryMatchLiteralConstraint(ret)) {
+    return match.value();
+  }
 
   return ApplyRewriteRules(ret);
 }
@@ -1419,7 +1423,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
     TVM_TRY_REWRITE(x - c1 == 0, x == c1);
     TVM_TRY_REWRITE(c1 - x == 0, x == c1);
     TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1);
-    TVM_TRY_REWRITE(x * y == 0, x == 0 || y == 0);
+    TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0);
   }
   return std::move(ret);
 }
diff --git a/src/tir/transforms/reduce_branching_through_overcompute.cc b/src/tir/transforms/reduce_branching_through_overcompute.cc
new file mode 100644
index 000000000000..8c8824719276
--- /dev/null
+++ b/src/tir/transforms/reduce_branching_through_overcompute.cc
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file reduce_branching_through_overcompute.cc
+ *
+ * \brief Attempt to remove conditional statements by introducing
+ * extra computations that do not impact the final results.
+ */
+
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
+
+#include <optional>
+
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../analysis/control_flow_graph.h"
+#include "remove_no_op.h"
+#include "simplify.h"
+
+namespace tvm {
+namespace tir {
+
+struct ReduceBranchingThroughOvercomputeConfigNode
+    : public tvm::AttrsNode<ReduceBranchingThroughOvercomputeConfigNode> {
+  bool use_dataflow_analysis;
+
+  TVM_DECLARE_ATTRS(ReduceBranchingThroughOvercomputeConfigNode,
+                    "tir.transform.ReduceBranchingThroughOvercomputeConfig") {
+    TVM_ATTR_FIELD(use_dataflow_analysis)
+        .describe(
+            "If true, known buffer values are propagated and used "
+            "to statically prove that overcompute is valid.")
+        .set_default(false);
+  }
+};
+
+class ReduceBranchingThroughOvercomputeConfig : public Attrs {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ReduceBranchingThroughOvercomputeConfig, Attrs,
+                                            ReduceBranchingThroughOvercomputeConfigNode);
+};
+
+TVM_REGISTER_NODE_TYPE(ReduceBranchingThroughOvercomputeConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.ReduceBranchingThroughOvercompute",
+                                ReduceBranchingThroughOvercomputeConfig);
+
+struct ElseBranchFiller : StmtExprMutator {
+  Stmt VisitStmt_(const IfThenElseNode* op) override {
+    IfThenElse ret = Downcast<IfThenElse>(StmtExprMutator::VisitStmt_(op));
+    if (ret->else_case.defined()) {
+      return std::move(ret);
+    } else {
+      auto new_else_clause = Evaluate(0);  // placeholder; tracked so it can be stripped later
+      new_else_clauses.insert(new_else_clause);
+      return IfThenElse(ret->condition, ret->then_case, new_else_clause);
+    }
+  }
+
+  std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual> new_else_clauses;
+};
+
+class ElseBranchStripper : public StmtExprMutator {
+ public:
+  ElseBranchStripper(
+      const std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual>& new_else_clauses)
+      : new_else_clauses_(new_else_clauses) {}
+
+ private:
+  Stmt VisitStmt_(const IfThenElseNode* op) override {
+    IfThenElse ret = Downcast<IfThenElse>(StmtExprMutator::VisitStmt_(op));
+    auto as_eval = ret->else_case.as<EvaluateNode>();
+    if (as_eval && new_else_clauses_.count(GetRef<Evaluate>(as_eval))) {
+      return IfThenElse(ret->condition, ret->then_case);  // drop the placeholder else
+    } else {
+      return std::move(ret);
+    }
+  }
+
+  const std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual>& new_else_clauses_;
+};
+
+class BranchReducer : public arith::IRMutatorWithAnalyzer {
+ public:
+  static Stmt Apply(Stmt stmt, const std::optional<ControlFlowGraph>& touch_pattern) {
+    arith::Analyzer analyzer;
+    BranchReducer visitor(&analyzer, touch_pattern);
+    return visitor(std::move(stmt));
+  }
+
+ private:
+  using Parent = IRMutatorWithAnalyzer;
+  using Parent::VisitStmt;
+  using Parent::VisitStmt_;
+
+  BranchReducer(arith::Analyzer* analyzer, const std::optional<ControlFlowGraph>& touch_pattern)
+      : Parent(analyzer), touch_pattern_(touch_pattern) {}
+
+  Stmt VisitStmt_(const IfThenElseNode* op) final {
+    IfThenElse cond = Downcast<IfThenElse>(Parent::VisitStmt_(op));
+
+    auto is_special_case = [&](PrimExpr condition, Stmt general_case, Stmt special_case) -> bool {
+      condition = analyzer_->rewrite_simplify(condition);
+      With<arith::ConstraintContext> constraint(analyzer_, condition);
+      Stmt stmt = RemoveNoOp(general_case, analyzer_, touch_pattern_, special_case.get());
+      return StructuralEqual()(stmt, special_case);
+    };
+
+    ICHECK(cond->else_case.defined() || !touch_pattern_.has_value())
+        << "Temp assert, should be true whenever touch pattern is available";
+    Stmt else_case = cond->else_case.value_or(Evaluate(0));
+
+    if (is_special_case(cond->condition, else_case, cond->then_case)) {
+      return else_case;  // else_case is also valid when the condition holds
+    } else if (is_special_case(!cond->condition, cond->then_case, else_case)) {
+      return cond->then_case;  // then_case is also valid when the condition fails
+    } else {
+      return std::move(cond);
+    }
+  }
+
+ private:
+  const std::optional<ControlFlowGraph>& touch_pattern_;
+};
+
+namespace transform {
+
+Pass ReduceBranchingThroughOvercompute() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    // Read the pass configuration, falling back to the defaults (no
+    // dataflow analysis) when the PassContext does not provide one.
+    ReduceBranchingThroughOvercomputeConfig config =
+        ctx->GetConfig<ReduceBranchingThroughOvercomputeConfig>(
+               "tir.ReduceBranchingThroughOvercompute")
+            .value_or(AttrsWithDefaultValues<ReduceBranchingThroughOvercomputeConfig>());
+
+    auto* n = f.CopyOnWrite();
+
+    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
+    ElseBranchFiller else_branch_filler;
+    if (config->use_dataflow_analysis) {
+      n->body = else_branch_filler(std::move(n->body));
+      touch_pattern.emplace(n->body);
+    }
+
+    n->body = BranchReducer::Apply(std::move(n->body), touch_pattern);
+
+    if (config->use_dataflow_analysis) {
+      n->body = ElseBranchStripper(else_branch_filler.new_else_clauses)(std::move(n->body));
+    }
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.ReduceBranchingThroughOvercompute", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.ReduceBranchingThroughOvercompute")
+    .set_body_typed(ReduceBranchingThroughOvercompute);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc
index 3374f975f5ac..430c1f41bfaf 100644
--- a/src/tir/transforms/remove_no_op.cc
+++ b/src/tir/transforms/remove_no_op.cc
@@ -220,16 +220,21 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
         touch_pattern_->RemoveStore(store);
         return only_side_effects();
       }
+    }
 
-      // A write whose destination is known to already contain the
-      // values to be written is a no-op.
-      PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices);
-
-      PrimExpr simplified =
-          touch_pattern_->SimplifyInContext(stores_existing_value, context, analyzer_);
-      if (auto* as_int = as_const_int(simplified); as_int && *as_int) {
-        return only_side_effects();
-      }
+    // A write whose destination is known to already contain the
+    // values to be written is a no-op.
+    // The check is phrased as (value - load == 0), presumably so common terms can cancel.
+    PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices) == 0;
+    if (touch_pattern_.has_value()) {
+      Stmt context_arg = context_ ? GetRef<Stmt>(context_) : Stmt(store);
+      stores_existing_value =
+          touch_pattern_->SimplifyInContext(stores_existing_value, context_arg, analyzer_);
+    } else {
+      stores_existing_value = analyzer_->Simplify(stores_existing_value);
+    }
+    if (is_one(stores_existing_value)) {
+      return only_side_effects();
     }
 
     // If the stored value is a load from the same location, the
@@ -293,6 +298,11 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
   const StmtNode* context_;
 };
 
+Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, std::optional<ControlFlowGraph> touch_pattern,
+                const StmtNode* context) {
+  return NoOpRemover::Apply(std::move(stmt), analyzer, std::move(touch_pattern), context);
+}
+
 namespace transform {
 
 Pass RemoveNoOp() {
@@ -306,10 +316,6 @@ Pass RemoveNoOp() {
     }
 
     arith::Analyzer analyzer;
-    analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
-        arith::RewriteSimplifier::kTransitivelyProveInequalities |
-        arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
-        arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
 
     auto* n = f.CopyOnWrite();
     n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr);
diff --git a/src/tir/transforms/remove_no_op.h b/src/tir/transforms/remove_no_op.h
new file mode 100644
index 000000000000..e24c32b5da18
--- /dev/null
+++ b/src/tir/transforms/remove_no_op.h
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file remove_no_op.h
+ * \brief Utility to remove no-op statements, usable as a subroutine in other passes.
+ */
+#ifndef TVM_TIR_TRANSFORMS_REMOVE_NO_OP_H_
+#define TVM_TIR_TRANSFORMS_REMOVE_NO_OP_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/stmt.h>
+
+#include <optional>
+
+#include "../analysis/control_flow_graph.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Remove no-ops from the statement
+ *
+ * Applies the same behavior as the tir.transform.RemoveNoOp pass, but
+ * on a single statement, usable as a subroutine in other passes.
+ *
+ * \param stmt The TIR statement from which to remove no-ops
+ *
+ * \param analyzer The analyzer to use while proving no-ops
+ *
+ * \param touch_pattern The analyzed control-flow graph, which contains
+ * the `stmt` to be analyzed.  If provided, known buffer values will
+ * be used to remove no-ops.  (e.g. Removing `buf[i] = 0` in cases
+ * where `buf[i]` is known to already contain zero.)  If std::nullopt,
+ * known buffer values will not be used.
+ * \param context If non-null, passed to SimplifyInContext in place of the store being checked.
+ * \return The modified statement with no-ops removed
+ */
+Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer,
+                std::optional<ControlFlowGraph> touch_pattern = std::nullopt,
+                const StmtNode* context = nullptr);
+
+}  // namespace tir
+}  // namespace tvm
+#endif  // TVM_TIR_TRANSFORMS_REMOVE_NO_OP_H_
diff --git a/src/tir/transforms/simplify.h b/src/tir/transforms/simplify.h
new file mode 100644
index 000000000000..43afc5e48dcb
--- /dev/null
+++ b/src/tir/transforms/simplify.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file simplify.h
+ * \brief Utility to simplify a TIR statement, usable as a subroutine in other passes.
+ */
+#ifndef TVM_TIR_TRANSFORMS_SIMPLIFY_H_
+#define TVM_TIR_TRANSFORMS_SIMPLIFY_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Simplifies the statement
+ *
+ * Applies the same behavior as the tir.transform.Simplify pass, but
+ * on a single statement, usable as a subroutine in other passes.
+ */
+Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer);
+
+}  // namespace tir
+}  // namespace tvm
+#endif  // TVM_TIR_TRANSFORMS_SIMPLIFY_H_
diff --git a/tests/python/unittest/test_tir_transform_reduce_branching_through_overcompute.py b/tests/python/unittest/test_tir_transform_reduce_branching_through_overcompute.py
new file mode 100644
index 000000000000..13fbcc7594ec
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_reduce_branching_through_overcompute.py
@@ -0,0 +1,219 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+
+import pytest  # NOTE(review): not referenced anywhere in this file -- confirm it is needed
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    use_dataflow_analysis = False  # forwarded to the pass config; subclasses override to True
+
+    def transform(self):
+        def inner(mod):
+            config = {
+                "tir.ReduceBranchingThroughOvercompute": {
+                    "use_dataflow_analysis": self.use_dataflow_analysis,
+                }
+            }
+            with tvm.transform.PassContext(config=config):
+                mod = tvm.tir.transform.ReduceBranchingThroughOvercompute()(mod)
+            return mod
+
+        return inner
+
+
+class TestIntroduceNoOp(BaseBeforeAfter):
+    """Remove a conditional by introducing a no-op
+
+    If the else_case can have a no-op added in order to be identical
+    to the then_case, then the conditional can be removed.
+    """
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i < 14:
+                A[i] = 1
+                T.evaluate(0)
+            else:
+                A[i] = 1
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 1
+            T.evaluate(0)
+
+
+class TestIntroduceAdditionOfZero(BaseBeforeAfter):
+    """Insert a conditionally no-op statement
+
+    Overcompute doesn't need to explicitly be a no-op, and can be
+    something that simplifies to a no-op.  Here, when i==0, the
+    expression simplifies to ``A[0] = A[0]``, which is a no-op.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[1, "int32"]):
+        for i in T.serial(16):
+            if i > 0:
+                A[0] = A[0] + i * i
+
+    def expected(A: T.Buffer[1, "int32"]):
+        for i in T.serial(16):
+            A[0] = A[0] + i * i
+
+
+class TestIntroduceAdditionOfKnownZeroInBuffer(BaseBeforeAfter):
+    """Insert a conditionally no-op statement
+
+    Proving that the overcompute is a no-op may use known values that
+    are present in a buffer.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+        for i in T.serial(16):
+            T.evaluate(T.assume(i < 14 or A[i] == 0))  # A[14] and A[15] are assumed to be zero
+
+        B[0] = 0
+        for i in T.serial(16):
+            if i < 14:
+                B[0] = B[0] + A[i]
+
+    def expected(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+        for i in T.serial(16):
+            T.evaluate(T.assume(i < 14 or A[i] == 0))
+
+        B[0] = 0
+        for i in T.serial(16):
+            B[0] = B[0] + A[i]
+
+
+class TestIntroduceOverwrittenWrite(BaseBeforeAfter):
+    """Insert a write that is later overwritten.
+
+    Given two sequential writes to the same location without a read
+    occurring in-between, the first is a no-op.  Therefore, the
+    conditional in the first loop can be removed, with any temporary
+    values overwritten by the second loop.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i < 14:
+                A[i] = 1
+
+        for i in T.serial(16):
+            if i >= 14:
+                A[i] = 2
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 1
+
+        for i in T.serial(16):
+            if i >= 14:
+                A[i] = 2
+
+
+class TestMaintainValuesUsedLater(BaseBeforeAfter):
+    """Do not insert writes that would be used later.
+
+    As TestIntroduceOverwrittenWrite, except that the values stored at
+    A[14] and A[15] are used by the second loop.  Overwriting them in
+    the first loop would change the result, so the overcompute would
+    not be valid.
+    """
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i < 14:
+                A[i] = 1
+
+        for i in T.serial(16):
+            if i >= 14:
+                A[i] = A[i] + 1
+
+    expected = before  # NOTE(review): runs with use_dataflow_analysis=False -- confirm intended
+
+
+class TestIdentifyOverwrittenWriteFromEquivalentExpressions(BaseBeforeAfter):
+    """Insert a write that is later overwritten.
+
+    As TestIntroduceOverwrittenWrite, but the conditionals used in the
+    first and second loop have different structures while referring to
+    the same elements.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i < 14:
+                A[i] = 1
+
+        for io, ii in T.grid(4, 4):
+            if io == 3 and ii >= 2:
+                A[4 * io + ii] = 2
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 1
+
+        for io, ii in T.grid(4, 4):
+            if io == 3 and ii >= 2:
+                A[4 * io + ii] = 2
+
+
+class TestIntroduceSupersetOverwrittenWrite(BaseBeforeAfter):
+    """Insert a write that is later overwritten.
+
+    As TestIntroduceOverwrittenWrite, but the elements written in the
+    second loop are not distinct from the elements in the first loop.
+    So long as the writes introduced by overcompute in the first loop
+    are a subset of the writes present in the second loop, the
+    overcompute can be introduced.
+    """
+
+    use_dataflow_analysis = True  # NOTE(review): same body as TestIntroduceOverwrittenWrite
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i < 14:
+                A[i] = 1
+
+        for i in T.serial(16):
+            if i >= 14:
+                A[i] = 2
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 1
+
+        for i in T.serial(16):
+            if i >= 14:
+                A[i] = 2
+
+
+if __name__ == "__main__":
+    tvm.testing.main()