Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,17 @@ def UnrollLoop():
return _ffi_api.UnrollLoop() # type: ignore


def ReduceBranchingThroughOvercompute():
    """Create the pass that reduces branching by introducing overcompute.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.ReduceBranchingThroughOvercompute()  # type: ignore
    return fpass


def RemoveNoOp():
"""Remove No Op from the Stmt.

Expand Down
10 changes: 7 additions & 3 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1388,8 +1388,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
EQ ret = Downcast<EQ>(IRMutatorWithAnalyzer::VisitExpr_(op));
op = ret.get();

if (auto const_res = TryConstFold<EQ>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
if (auto const_res = TryConstFold<EQ>(op->a, op->b)) {
return const_res.value();
}
if (auto match = TryMatchLiteralConstraint(ret)) {
return match.value();
}

return ApplyRewriteRules(ret);
}
Expand Down Expand Up @@ -1419,7 +1423,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
TVM_TRY_REWRITE(x - c1 == 0, x == c1);
TVM_TRY_REWRITE(c1 - x == 0, x == c1);
TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1);
TVM_TRY_REWRITE(x * y == 0, x == 0 || y == 0);
TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0);
}
return std::move(ret);
}
Expand Down
178 changes: 178 additions & 0 deletions src/tir/transforms/reduce_branching_through_overcompute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file reduce_branching_through_overcompute.cc
*
* \brief Attempt to remove conditional statements by introducing
* extra computations that do not impact the final results.
*/

#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>

#include <optional>

#include "../../arith/ir_mutator_with_analyzer.h"
#include "../analysis/control_flow_graph.h"
#include "remove_no_op.h"
#include "simplify.h"

namespace tvm {
namespace tir {

/*!
 * \brief Attribute node holding the options of the
 *  ReduceBranchingThroughOvercompute pass.
 */
struct ReduceBranchingThroughOvercomputeConfigNode
: public tvm::AttrsNode<ReduceBranchingThroughOvercomputeConfigNode> {
/*! \brief Whether known buffer values may be used to prove that overcompute is valid. */
bool use_dataflow_analysis;

// Declares the attribute schema; `use_dataflow_analysis` defaults to false.
TVM_DECLARE_ATTRS(ReduceBranchingThroughOvercomputeConfigNode,
"tir.transform.ReduceBranchingThroughOvercomputeConfig") {
TVM_ATTR_FIELD(use_dataflow_analysis)
.describe(
"If true, known buffer values are propagated and used "
"to statically prove that overcompute is valid.")
.set_default(false);
}
};

/*!
 * \brief Non-nullable object reference to a
 *  ReduceBranchingThroughOvercomputeConfigNode.
 */
class ReduceBranchingThroughOvercomputeConfig : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ReduceBranchingThroughOvercomputeConfig, Attrs,
ReduceBranchingThroughOvercomputeConfigNode);
};

// Register the config node type, and declare the pass-config key
// "tir.ReduceBranchingThroughOvercompute" that the pass looks up through
// PassContext::GetConfig.
TVM_REGISTER_NODE_TYPE(ReduceBranchingThroughOvercomputeConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ReduceBranchingThroughOvercompute",
ReduceBranchingThroughOvercomputeConfig);

/*!
 * \brief Mutator that gives every IfThenElse lacking an else branch an
 *  explicit `Evaluate(0)` else branch.
 *
 * Each inserted placeholder is recorded in `new_else_clauses`, so that
 * it can be recognised (by object identity) later on.
 */
struct ElseBranchFiller : StmtExprMutator {
  Stmt VisitStmt_(const IfThenElseNode* op) override {
    IfThenElse visited = Downcast<IfThenElse>(StmtExprMutator::VisitStmt_(op));

    // An else branch that already exists is left untouched.
    if (visited->else_case.defined()) {
      return std::move(visited);
    }

    Evaluate placeholder = Evaluate(0);
    new_else_clauses.insert(placeholder);
    return IfThenElse(visited->condition, visited->then_case, placeholder);
  }

  /*! \brief The placeholder else branches inserted by this mutator. */
  std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual> new_else_clauses;
};

/*!
 * \brief Mutator that drops else branches which are members of a given
 *  set of placeholder `Evaluate` statements.
 *
 * The set is held by reference, so it must outlive this mutator.
 */
class ElseBranchStripper : public StmtExprMutator {
 public:
  ElseBranchStripper(
      const std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual>& new_else_clauses)
      : new_else_clauses_(new_else_clauses) {}

 private:
  Stmt VisitStmt_(const IfThenElseNode* op) override {
    IfThenElse visited = Downcast<IfThenElse>(StmtExprMutator::VisitStmt_(op));

    // Only an else branch that is one of the recorded placeholders is removed.
    const auto* else_eval = visited->else_case.as<EvaluateNode>();
    const bool is_placeholder =
        else_eval && new_else_clauses_.count(GetRef<Evaluate>(else_eval));

    if (!is_placeholder) {
      return std::move(visited);
    }
    return IfThenElse(visited->condition, visited->then_case);
  }

  /*! \brief Placeholder else branches that should be stripped. */
  const std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual>& new_else_clauses_;
};

class BranchReducer : public arith::IRMutatorWithAnalyzer {
public:
static Stmt Apply(Stmt stmt, const std::optional<ControlFlowGraph>& touch_pattern) {
arith::Analyzer analyzer;
BranchReducer visitor(&analyzer, touch_pattern);
return visitor(std::move(stmt));
}

private:
using Parent = IRMutatorWithAnalyzer;
using Parent::VisitStmt;
using Parent::VisitStmt_;

BranchReducer(arith::Analyzer* analyzer, const std::optional<ControlFlowGraph>& touch_pattern)
: Parent(analyzer), touch_pattern_(touch_pattern) {}

Stmt VisitStmt_(const IfThenElseNode* op) final {
IfThenElse cond = Downcast<IfThenElse>(Parent::VisitStmt_(op));

auto is_special_case = [&](PrimExpr condition, Stmt general_case, Stmt special_case) -> bool {
condition = analyzer_->rewrite_simplify(condition);
With<arith::ConstraintContext> constraint(analyzer_, condition);
Stmt stmt = RemoveNoOp(general_case, analyzer_, touch_pattern_, special_case.get());
return StructuralEqual()(stmt, special_case);
};

ICHECK(cond->else_case.defined() || !touch_pattern_.has_value())
<< "Temp assert, should be true whenever touch pattern is available";
Stmt else_case = cond->else_case.value_or(Evaluate(0));

if (is_special_case(cond->condition, else_case, cond->then_case)) {
return else_case;
} else if (is_special_case(!cond->condition, cond->then_case, else_case)) {
return cond->then_case;
} else {
return std::move(cond);
}
}

private:
const std::optional<ControlFlowGraph>& touch_pattern_;
};

namespace transform {

/*!
 * \brief Create the tir.ReduceBranchingThroughOvercompute PrimFunc pass.
 *
 * The pass reads its options from the "tir.ReduceBranchingThroughOvercompute"
 * pass-config entry, falling back to the default attribute values.  When
 * `use_dataflow_analysis` is set, missing else branches are first filled
 * with placeholders, a control-flow graph is built from the resulting
 * body and handed to BranchReducer, and the placeholders are stripped
 * again afterwards.
 *
 * \return The pass.
 */
Pass ReduceBranchingThroughOvercompute() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    ReduceBranchingThroughOvercomputeConfig config =
        ctx->GetConfig<ReduceBranchingThroughOvercomputeConfig>(
               "tir.ReduceBranchingThroughOvercompute")
            .value_or(AttrsWithDefaultValues<ReduceBranchingThroughOvercomputeConfig>());

    auto* n = f.CopyOnWrite();

    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
    // Declared outside the `if` so that the recorded placeholders are
    // still available when they are stripped below.
    ElseBranchFiller else_branch_filler;
    if (config->use_dataflow_analysis) {
      n->body = else_branch_filler(std::move(n->body));
      touch_pattern.emplace(n->body);
    }

    // BranchReducer::Apply constructs its own analyzer.
    n->body = BranchReducer::Apply(std::move(n->body), touch_pattern);

    if (config->use_dataflow_analysis) {
      n->body = ElseBranchStripper(else_branch_filler.new_else_clauses)(std::move(n->body));
    }
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.ReduceBranchingThroughOvercompute", {});
}

// Register the pass constructor as the global function
// "tir.transform.ReduceBranchingThroughOvercompute".
// NOTE(review): presumably this is what the Python `_ffi_api` wrapper resolves to; confirm.
TVM_REGISTER_GLOBAL("tir.transform.ReduceBranchingThroughOvercompute")
.set_body_typed(ReduceBranchingThroughOvercompute);

} // namespace transform

} // namespace tir
} // namespace tvm
32 changes: 19 additions & 13 deletions src/tir/transforms/remove_no_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,21 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
touch_pattern_->RemoveStore(store);
return only_side_effects();
}
}

// A write whose destination is known to already contain the
// values to be written is a no-op.
PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices);

PrimExpr simplified =
touch_pattern_->SimplifyInContext(stores_existing_value, context, analyzer_);
if (auto* as_int = as_const_int(simplified); as_int && *as_int) {
return only_side_effects();
}
// A write whose destination is known to already contain the
// values to be written is a no-op.
// PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices);
PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices) == 0;
if (touch_pattern_.has_value()) {
Stmt context_arg = context_ ? GetRef<Stmt>(context_) : Stmt(store);
stores_existing_value =
touch_pattern_->SimplifyInContext(stores_existing_value, context_arg, analyzer_);
} else {
stores_existing_value = analyzer_->Simplify(stores_existing_value);
}
if (is_one(stores_existing_value)) {
return only_side_effects();
}

// If the stored value is a load from the same location, the
Expand Down Expand Up @@ -293,6 +298,11 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
const StmtNode* context_;
};

// Free-function entry point to NoOpRemover, so that other passes can
// run no-op removal on a single statement.
Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, std::optional<ControlFlowGraph> touch_pattern,
                const StmtNode* context) {
  Stmt without_no_ops =
      NoOpRemover::Apply(std::move(stmt), analyzer, std::move(touch_pattern), context);
  return without_no_ops;
}

namespace transform {

Pass RemoveNoOp() {
Expand All @@ -306,10 +316,6 @@ Pass RemoveNoOp() {
}

arith::Analyzer analyzer;
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));

auto* n = f.CopyOnWrite();
n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr);
Expand Down
60 changes: 60 additions & 0 deletions src/tir/transforms/remove_no_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file remove_no_op.h
* \brief Expose the RemoveNoOp transformation on a single statement, for use as a subroutine in other passes.
*/
#ifndef TVM_TIR_TRANSFORMS_REMOVE_NO_OP_H_
#define TVM_TIR_TRANSFORMS_REMOVE_NO_OP_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt.h>

#include <optional>

#include "../analysis/control_flow_graph.h"

namespace tvm {
namespace tir {

/*! \brief Remove no-ops from the statement
 *
 * Applies the same behavior as the tir.transform.RemoveNoOp pass, but
 * on a single statement, usable as a subroutine in other passes.
 *
 * \param stmt The TIR statement from which to remove no-ops
 *
 * \param analyzer The analyzer to use while proving no-ops
 *
 * \param touch_pattern The analyzed control-flow graph, which contains
 * the `stmt` to be analyzed.  If provided, known buffer values will
 * be used to remove no-ops.  (e.g. Removing `buf[i] = 0` in cases
 * where `buf[i]` is known to already contain zero.)  If std::nullopt,
 * known buffer values will not be used.
 *
 * \param context Optional statement forwarded to the no-op remover.
 * NOTE(review): appears to select the statement that is used as the
 * control-flow-graph context when simplifying with known buffer
 * values, in place of the statement being visited; confirm against
 * the implementation in remove_no_op.cc.
 *
 * \return The modified statement with no-ops removed
 */
Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer,
std::optional<ControlFlowGraph> touch_pattern = std::nullopt,
const StmtNode* context = nullptr);

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_TRANSFORMS_REMOVE_NO_OP_H_
42 changes: 42 additions & 0 deletions src/tir/transforms/simplify.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file simplify.h
* \brief Expose the Simplify transformation on a single statement, for use as a subroutine in other passes.
*/
#ifndef TVM_TIR_TRANSFORMS_SIMPLIFY_H_
#define TVM_TIR_TRANSFORMS_SIMPLIFY_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt.h>

namespace tvm {
namespace tir {

/*! \brief Simplifies the statement
 *
 * Applies the same behavior as the tir.transform.Simplify pass, but
 * on a single statement, usable as a subroutine in other passes.
 *
 * \param stmt The TIR statement to simplify
 *
 * \param analyzer The analyzer to use while simplifying
 *
 * \return The simplified statement
 */
Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer);

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_TRANSFORMS_SIMPLIFY_H_
Loading