From 1a7c98642935d297b8a9174b0fa7b73fff2d5b41 Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Thu, 18 Jul 2024 01:05:02 -0500 Subject: [PATCH 1/9] Implementation to eliminate redundant branch introduced due to operator padding and overcompute, this creates more opportunities to vectorize the code --- include/tvm/tir/transform.h | 8 + python/tvm/tir/transform/transform.py | 11 + .../using_assume_to_reduce_branches.cc | 389 +++++++++++ ...nate_pad_branch_using_buffer_assumption.py | 654 ++++++++++++++++++ 4 files changed, 1062 insertions(+) create mode 100644 src/tir/transforms/using_assume_to_reduce_branches.cc create mode 100644 tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 98edbeaceb26..fccdc566a693 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -834,6 +834,14 @@ TVM_DLL Pass InstrumentProfileIntrinsics(); */ TVM_DLL Pass DefaultGPUSchedule(); +/*! + * \brief This pass analyzes primfunc and eliminates branch introdued due to layout specific padding. + * It leverages from the buffer assumptions and use the information to eliminate the branch. + * \note This creates more opportunity to vectorize the code. + * \return The Pass. + */ +TVM_DLL Pass UseAssumeToReduceBranches(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index c2022b918643..cf37deaec1da 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -1199,3 +1199,14 @@ def DefaultGPUSchedule(): ret: tvm.transform.Pass """ return _ffi_api.DefaultGPUSchedule() # type: ignore + +def UseAssumeToReduceBranches(): + """This pass attempts to eliminates layout specific pad branch by overcomputing the values for padded region. + Eliminating the branch will help to vectorize the code and improve element wise ops performance. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.UseAssumeToReduceBranches() # type: ignore diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc new file mode 100644 index 000000000000..99cff10fc95e --- /dev/null +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file using_assume_to_reduce_branches.cc + * + * \brief Attempt to remove conditional branch statements by introducing + * extra computations that do not impact the final results. Mainly + * oriented for layout specific padding related branches. + * + * \note + * 1. This pass works if the buffer assumption variable is in the branch statement. + * In case, the buffer assumption is not present in the branch statement and + * there are intermediate buffers then, inline the code. + * 2. The assumptions leveraged here should be of the form T.assume(condition_on_indices or + * buffer_equals_to_some_value) + * 3. Some part of the code are reused from the control_flow_graph.cc file which also + * handles eliminating branches in particular scenarios. + * 4. This pass currently works for op_pattern kElemWise and kBroadcast. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../arith/constraint_extract.h" +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/unwrap_vector_expr.h" +#include "simplify.h" +#include "tvm/ir/expr.h" +namespace tvm { +namespace tir { + +using namespace arith; + +class AssumeChecker : public StmtExprVisitor { + /* This class checks if the primfunc has assume statement. + If yes, then only the FuncAnanlyzerMutator class runs. This is to ensure speedup in the pass.*/ + public: + bool has_assume = false; + + void VisitStmt(const Stmt& stmt) final { + if (has_assume) { + return; + } + StmtVisitor::VisitStmt(stmt); + } + void VisitExpr_(const CallNode* op) override { + if (op->op.same_as(builtin::assume())) { + has_assume = true; + } + } +}; + +class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { + /* This class analyzes the complete primfunc. + It parses the buffer assumptions and eliminates the redundant branch + introduced due to layout specific padding by leveraging from buffer assumptions. + On eliminating the branch there are more opportunities to vectorize the code and improve performance. + + Example: + ------------- + Prim Func Before : + for (...) + T.assume( assume_condition or A[i] == 0 ) + for (...) + out = T.if_then_else(if_then_else_condition, 0, function(A)) # here function(A) is some function on Var A + + Prim Func After : + for (...) + T.assume( assume_condition or A[i] == 0 ) + for (...) + out = function(A) # here function(A) is some function on the Var A + -------------- + # High-level implementation details : + 1. The pass parses the assume statement and stores the relevant information. + 2. The pass tries to evaluate the then_clause and else_clause in then_condition_context and else_condition_context. + It checks if the context of the assume statement (for condition indices and + assume_condition) is same as the context of the if_then_else statement (for condition indices and + if_then_else condition). If context is same and the expression inside if_then_else statement is a function of the + buffer assumption (eg A in above example), then the pass substitutes the value from the buffer assumption and + simplifies the expression . + 3. The pass then checks if then_clause and else_clause evaluate to same value. + If yes, then return the else_clause if we are in the then_condition_context (since then_clause + will be true in this context and if else_clause is also evaluating to true then we can directly + replace it with else_clause), similarly, we return the then_clause if we are in the + else_condition_context. + This class handles all these scenarios.*/ + + public: + using Parent = IRMutatorWithAnalyzer; + explicit ParseAssumeAndOvercompute(Analyzer* analyzer) : Parent(analyzer) {} + + private: + using Parent::VisitExpr_; + using Parent::VisitStmt; + using Parent::VisitStmt_; + + // This struct stores all the relevant data related to asssume statement + struct assume_struct { // Consider the example : T.assume(i < 14 or A[i] == 0) + PrimExpr buffer_context; // The context of the assume statement (the bound on the axis) + PrimExpr buffer_predicate; // The condition inside assume statement (i < 14) excluding + // bufferload expression (A[i] == 0) + tir::BufferLoad buffer_load; // Storing the buffer load Eg: A[i] in A[i] == 0 + PrimExpr buffer_value; // Storing the value for the buffer Eg : 0 in A[i] == 0 + Array buffer_indices; // Storing the indices of the buffer Eg : i + }; + // List of conditions in a scope + std::vector conditions_; + + // Storing all the buffer assumptions data in map + std::map map_buffer_assumption; + tir::Buffer current_bufferstorenode_name; + + struct InternalConstraintContext { + /* This stuct appends the constraint passed to it in the conditions list. + It keeps track of the bounds of the variables along with any conditions on the variables */ + InternalConstraintContext(ParseAssumeAndOvercompute* self, PrimExpr constraint) + : self(self), analyzer_context(self->analyzer_, constraint) { + old_num_constraints = self->conditions_.size(); + + auto side_effect = tir::SideEffect(constraint); + if (side_effect <= tir::CallEffectKind::kPure) { + self->conditions_.push_back(constraint); + } else if (side_effect <= tir::CallEffectKind::kReadState) { + assume = constraint; + } + + new_num_constraints = self->conditions_.size(); + } + + ~InternalConstraintContext() { + ICHECK_EQ(self->conditions_.size(), new_num_constraints) + << "Internal error: Each condition should only be popped once."; + self->conditions_.erase(self->conditions_.begin() + old_num_constraints, + self->conditions_.end()); + } + + ParseAssumeAndOvercompute* self{nullptr}; + With analyzer_context; + size_t old_num_constraints{0}; + size_t new_num_constraints{0}; + Optional assume{NullOpt}; + + // Disable default-generated copy/move assignment and constructors + InternalConstraintContext(const InternalConstraintContext&) = delete; + InternalConstraintContext& operator=(const InternalConstraintContext&) = delete; + InternalConstraintContext(InternalConstraintContext&&) = delete; + InternalConstraintContext& operator=(InternalConstraintContext&&) = delete; + }; + + PrimExpr CurrentScopePredicate() const { + /* This combines all the constraints in a scope */ + PrimExpr predicate = Bool(true); + for (const auto& condition : conditions_) { + predicate = predicate && condition; + } + return predicate; + } + + Stmt VisitStmt_(const ForNode* op) final { + /* Create and delete the scope with bind. + Add the minimum and maximum bound for the variables to the conditions_ list using + InternalConstraintContext */ + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + InternalConstraintContext ctx1(this, op->loop_var >= op->min); + InternalConstraintContext ctx2(this, op->loop_var < op->min + op->extent); + return Parent::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) override { + if(map_buffer_assumption.find(op->buffer) != map_buffer_assumption.end()){ + PrimExpr buf_value; + /* If the cuurent context where the buffer load is present is same as + the context of the buffer assumption then, return the buffer value present in the assumption. + This will eventually replace the bufferload value in the complete expresison */ + + auto buffer_assumption = map_buffer_assumption[op->buffer]; + PrimExpr current_predicate_and_context = CurrentScopePredicate(); + PrimExpr buffer_predicate_and_context = + buffer_assumption.buffer_context && buffer_assumption.buffer_predicate; + bool current_context_and_buffer_constraint_is_same = StructuralEqual()( + current_predicate_and_context, buffer_predicate_and_context, /*map_free_vars=*/true); + + if (current_context_and_buffer_constraint_is_same) { + buf_value = buffer_assumption.buffer_value; + return buf_value; + } + } + return GetRef(op); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(Parent::VisitStmt_(op)); + + // Eliminate the builtin if_then_else statement + if (auto* call = op->value.as()) { + if (call->op.same_as(builtin::if_then_else())) { + PrimExpr cond = call->args[0]; + PrimExpr then_clause = call->args[1]; + PrimExpr else_clause = call->args[2]; + + PrimExpr then_clause_in_then_context; + PrimExpr else_clause_in_then_context; + PrimExpr then_clause_in_else_context; + PrimExpr else_clause_in_else_context; + { + // Simplifying expressions in " then context " + InternalConstraintContext then_ctx(this, cond); + // This will call the current class's appropriate VisitStmt function + then_clause_in_then_context = (*this)(then_clause); + then_clause_in_then_context = analyzer_->Simplify(then_clause_in_then_context); + + else_clause_in_then_context = (*this)(else_clause); + else_clause_in_then_context = analyzer_->Simplify(else_clause_in_then_context); + } + { + // Simplifying expressions in " else context " + InternalConstraintContext else_ctx(this, !cond); + // This will call the current class's appropriate VisitStmt function + then_clause_in_else_context = (*this)(then_clause); + then_clause_in_else_context = analyzer_->Simplify(then_clause_in_else_context); + + else_clause_in_else_context = (*this)(else_clause); + else_clause_in_else_context = analyzer_->Simplify(else_clause_in_else_context); + } + + auto n = this->CopyOnWrite(op); + if (StructuralEqual()(then_clause_in_then_context, else_clause_in_then_context)) { + n->value = analyzer_->Simplify(else_clause); + return Stmt(n); + } else if (StructuralEqual()(then_clause_in_else_context, else_clause_in_else_context)) { + n->value = analyzer_->Simplify(then_clause); + return Stmt(n); + } else { + return Parent::VisitStmt_(op); + } + } + } + return Parent::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode* op) override { + if (op->op.same_as(builtin::assume())) { + Assume(op->args[0]); + } + return Parent::VisitExpr_(op); + } + + void Assume(PrimExpr assumption) { + for (const auto& expr : arith::ExtractConstraints(assumption, false)) { + AssumeConstraintComponent(expr); + } + } + + void AssumeConstraintComponent(PrimExpr assumption) { + PrimExpr additional_predicate = Bool(true); + assume_struct buf_data; + + std::vector buffer_exprs; + for (const auto& expr : arith::ExtractComponents(assumption)) { + auto side_effect = tir::SideEffect(expr); + if (side_effect <= tir::CallEffectKind::kPure) { + // Pulling out portions of the assumption that do not depend + // on a buffer value allows the following two forms to be + // treated identically. + // + // Option 1: if i < 3: T.assume(buf[i] == value) + // Option 2: T.assume(i>=3 or buf[i] == value) + additional_predicate = additional_predicate && logical_not(expr); + } else if (side_effect == tir::CallEffectKind::kReadState) { + buffer_exprs.push_back(expr); + } else { + LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr + << " with side-effect \'" << side_effect << "\'"; + } + } + + additional_predicate = analyzer_->Simplify(std::move(additional_predicate)); + CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression"; + + auto* as_equal_node = buffer_exprs[0].as(); + CHECK(as_equal_node) << "T.assume buffer constraint must be of the form 'buffer[indices] == " + "value', but received " + << assumption; + if (!as_equal_node) { + // This assumption is an inequality on a data-dependent + // conditional. Not an error for this to occur, but also not + // something that is currently supported. + return; + } + + // Parse the statement and store the desired values + // Ex: A[i]==0, load = A[i], value = 0 + tir::BufferLoad load; + PrimExpr value; + if (auto opt = as_equal_node->a.as()) { + load = opt.value(); + value = as_equal_node->b; + } else if (auto opt = as_equal_node->b.as()) { + load = opt.value(); + value = as_equal_node->a; + } else { + LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'"; + } + + // Populating the assume statement predicate, buffer, value + // and the context of the assume statement + buf_data.buffer_context = CurrentScopePredicate(); + buf_data.buffer_predicate = additional_predicate; + buf_data.buffer_load = load; + buf_data.buffer_value = value; + buf_data.buffer_indices = load->indices; + for (size_t i = 0; i < load->indices.size(); i++) { + buf_data.buffer_indices.push_back(analyzer_->Simplify(load->indices[i])); + } + map_buffer_assumption[buf_data.buffer_load->buffer] = buf_data; + + auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure; + CHECK(!has_side_effect) << "Buffer value in constraint must be pure expression, but was " + << value; + if (has_side_effect) { + return; + } + } +}; + +namespace transform { + +Pass UseAssumeToReduceBranches() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + arith::Analyzer analyzer; + + // The pass runs & eliminates pad branch with overcompute only if, the primfunc has op_pattern defined and is an elementwise op. + // AnnotateTIROpPattern pass will help to set the op_pattern in the op attributes of the primfunc. + if (n->attrs.GetAttr("op_pattern").defined()) { + Optional opt_pattern = f->GetAttr("op_pattern"); + if (opt_pattern.defined()) { + relay::OpPatternKind pattern = static_cast(Downcast(opt_pattern)->value); + + if (pattern == relay::OpPatternKind::kElemWise or + pattern == relay::OpPatternKind::kBroadcast) { + // If the primfunc contains assume statement then, run the mutator pass. + AssumeChecker assume_checker; + assume_checker(std::move(n->body)); + + if (assume_checker.has_assume) { + // Leverage from assume and eliminate the branch + ParseAssumeAndOvercompute func_analyzer_mutator(&analyzer); + n->body = func_analyzer_mutator(std::move(n->body)); + } + } + } + } + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.UseAssumeToReduceBranches", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.UseAssumeToReduceBranches") + .set_body_typed(UseAssumeToReduceBranches); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py new file mode 100644 index 000000000000..374503d215e7 --- /dev/null +++ b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py @@ -0,0 +1,654 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This test runs the reduce_pad_branch_through_over_compute test to check if we are able to eliminate the redundant pad branch and overcompute the value. +# This helps to expose more opportunities to vectorize the code. + +import tvm +import tvm.testing +from tvm import relax + +import tvm.script +from tvm.script import tir as T, relax as R + + +@tvm.script.ir_module +class Add_PrimFunc_Before: + @T.prim_func(private=True) + def add( + A: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + B: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.add", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "add", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("compute"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads( + A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + compute[ + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 + ] = T.if_then_else( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5, + T.uint8(0), + A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + + B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + + @R.function + def main( + A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + Add_PrimFunc_Before.add, + (A, B), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class Add_PrimFunc_Expected: + @T.prim_func(private=True) + def add( + A: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + B: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.add", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "add", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) + ): + for axis5_1_axis6_fused in T.vectorized(T.int64(128)): + with T.block("compute"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( + "SSSS", [axis1, axis2, axis3, axis4] + ) + v_axis5 = T.axis.spatial( + T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused // T.int64(32) + ) + v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) + T.reads( + A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes( + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( + A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + + B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + + @R.function + def main( + A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + Add_PrimFunc_Expected.add, + (A, B), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class Sub_PrimFunc_Before: + @T.prim_func(private=True) + def sub( + A: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + B: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.subtract", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "sub", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("compute"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads( + A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + compute[ + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 + ] = T.if_then_else( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5, + T.uint8(0), + A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + - B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + + @R.function + def main( + A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + Sub_PrimFunc_Before.sub, + (A, B), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class Sub_PrimFunc_Expected: + @T.prim_func(private=True) + def sub( + A: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + B: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.subtract", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "sub", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) + ): + for axis5_1_axis6_fused in T.vectorized(T.int64(128)): + with T.block("compute"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( + "SSSS", [axis1, axis2, axis3, axis4] + ) + v_axis5 = T.axis.spatial( + T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused // T.int64(32) + ) + v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) + T.reads( + A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes( + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( + A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + - B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + + @R.function + def main( + A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + Sub_PrimFunc_Expected.sub, + (A, B), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class Mul_PrimFunc_Before: + @T.prim_func(private=True) + def mul( + A: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + B: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.mul", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "mul", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("compute"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads( + A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + compute[ + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 + ] = T.if_then_else( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5, + T.uint8(0), + A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + * B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + + @R.function + def main( + A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + Mul_PrimFunc_Before.mul, + (A, B), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class Mul_PrimFunc_Expected: + @T.prim_func(private=True) + def mul( + A: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + B: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.mul", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "mul", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) + ): + for axis5_1_axis6_fused in T.vectorized(T.int64(128)): + with T.block("compute"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( + "SSSS", [axis1, axis2, axis3, axis4] + ) + v_axis5 = T.axis.spatial( + T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused // T.int64(32) + ) + v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) + T.reads( + A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes( + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( + A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + * B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + + @R.function + def main( + A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + Mul_PrimFunc_Expected.mul, + (A, B), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +def test_add_primfunc_overcompute(): + Add_PrimFunc_After = tvm.tir.transform.UseAssumeToReduceBranches()(Add_PrimFunc_Before) + tvm.ir.structural_equal( + Add_PrimFunc_After["add"], Add_PrimFunc_Expected["add"], map_free_vars=True + ) + + +def test_sub_primfunc_overcompute(): + Sub_PrimFunc_After = tvm.tir.transform.UseAssumeToReduceBranches()(Sub_PrimFunc_Before) + tvm.ir.structural_equal( + Sub_PrimFunc_After["sub"], Sub_PrimFunc_Expected["sub"], map_free_vars=True + ) + + +def test_mul_primfunc_overcompute(): + Mul_PrimFunc_After = tvm.tir.transform.UseAssumeToReduceBranches()(Mul_PrimFunc_Before) + tvm.ir.structural_equal( + Mul_PrimFunc_After["mul"], Mul_PrimFunc_Expected["mul"], map_free_vars=True + ) + + +if __name__ == "__main__": + tvm.testing.main() From ee704fa022606fb1dca3df75f02496356b00cb29 Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Thu, 18 Jul 2024 01:24:51 -0500 Subject: [PATCH 2/9] Fixed lint error in transform.py file --- python/tvm/tir/transform/transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index cf37deaec1da..5204bd7acc50 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -1200,10 +1200,11 @@ def DefaultGPUSchedule(): """ return _ffi_api.DefaultGPUSchedule() # type: ignore + def UseAssumeToReduceBranches(): """This pass attempts to eliminates layout specific pad branch by overcomputing the values for padded region. Eliminating the branch will help to vectorize the code and improve element wise ops performance. - + Returns ------- fpass : tvm.transform.Pass From 981717bb7193035c6b879d4ae17ee3d842185bb9 Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Thu, 18 Jul 2024 11:44:49 -0500 Subject: [PATCH 3/9] Fixed lint errors in the file using_assume_to_reduce_branches.cc --- .../using_assume_to_reduce_branches.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index 99cff10fc95e..95e37ea8327b 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -25,9 +25,9 @@ * oriented for layout specific padding related branches. * * \note - * 1. This pass works if the buffer assumption variable is in the branch statement. - * In case, the buffer assumption is not present in the branch statement and - * there are intermediate buffers then, inline the code. + * 1. This pass works if the buffer assumption variable is in the branch statement. + * In case, the buffer assumption is not present in the branch statement and + * there are intermediate buffers then, inline the code. * 2. The assumptions leveraged here should be of the form T.assume(condition_on_indices or * buffer_equals_to_some_value) * 3. Some part of the code are reused from the control_flow_graph.cc file which also @@ -76,10 +76,10 @@ class AssumeChecker : public StmtExprVisitor { class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { /* This class analyzes the complete primfunc. - It parses the buffer assumptions and eliminates the redundant branch - introduced due to layout specific padding by leveraging from buffer assumptions. + It parses the buffer assumptions and eliminates the redundant branch + introduced due to layout specific padding by leveraging from buffer assumptions. On eliminating the branch there are more opportunities to vectorize the code and improve performance. - + Example: ------------- Prim Func Before : @@ -87,7 +87,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { T.assume( assume_condition or A[i] == 0 ) for (...) out = T.if_then_else(if_then_else_condition, 0, function(A)) # here function(A) is some function on Var A - + Prim Func After : for (...) T.assume( assume_condition or A[i] == 0 ) @@ -100,7 +100,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { It checks if the context of the assume statement (for condition indices and assume_condition) is same as the context of the if_then_else statement (for condition indices and if_then_else condition). If context is same and the expression inside if_then_else statement is a function of the - buffer assumption (eg A in above example), then the pass substitutes the value from the buffer assumption and + buffer assumption (eg A in above example), then the pass substitutes the value from the buffer assumption and simplifies the expression . 3. The pass then checks if then_clause and else_clause evaluate to same value. If yes, then return the else_clause if we are in the then_condition_context (since then_clause From bc41d235557d3debdf696c862a51d76f995f4e1e Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Thu, 18 Jul 2024 13:47:26 -0500 Subject: [PATCH 4/9] Fixed lint error in transform.py related to line too long --- python/tvm/tir/transform/transform.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 5204bd7acc50..d8531401d49d 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -1202,8 +1202,9 @@ def DefaultGPUSchedule(): def UseAssumeToReduceBranches(): - """This pass attempts to eliminates layout specific pad branch by overcomputing the values for padded region. - Eliminating the branch will help to vectorize the code and improve element wise ops performance. + """This pass attempts to eliminates layout specific pad branch by overcomputing the values + for padded region. Eliminating the branch will help to vectorize code, + and improve element wise ops performance. Returns ------- From cc5cdcc076b017fdf03e9134d8917026b8ff3061 Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Mon, 22 Jul 2024 00:18:42 -0500 Subject: [PATCH 5/9] Fixed Lint error related to space and length of the sentence in using_assume_to_reduce_branches.cc --- .../transforms/using_assume_to_reduce_branches.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index 95e37ea8327b..996746bc99f2 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -191,7 +191,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const BufferLoadNode* op) override { - if(map_buffer_assumption.find(op->buffer) != map_buffer_assumption.end()){ + if (map_buffer_assumption.find(op->buffer) != map_buffer_assumption.end()) { PrimExpr buf_value; /* If the cuurent context where the buffer load is present is same as the context of the buffer assumption then, return the buffer value present in the assumption. @@ -354,14 +354,16 @@ Pass UseAssumeToReduceBranches() { auto* n = f.CopyOnWrite(); arith::Analyzer analyzer; - // The pass runs & eliminates pad branch with overcompute only if, the primfunc has op_pattern defined and is an elementwise op. - // AnnotateTIROpPattern pass will help to set the op_pattern in the op attributes of the primfunc. + // The pass runs & eliminates pad branch with overcompute only if, + // the primfunc has op_pattern defined and is an elementwise op. + // AnnotateTIROpPattern pass will set op_pattern in op attributes of the primfunc. if (n->attrs.GetAttr("op_pattern").defined()) { Optional opt_pattern = f->GetAttr("op_pattern"); if (opt_pattern.defined()) { - relay::OpPatternKind pattern = static_cast(Downcast(opt_pattern)->value); + relay::OpPatternKind pattern; + pattern = static_cast(Downcast(opt_pattern)->value); - if (pattern == relay::OpPatternKind::kElemWise or + if (pattern == relay::OpPatternKind::kElemWise || pattern == relay::OpPatternKind::kBroadcast) { // If the primfunc contains assume statement then, run the mutator pass. AssumeChecker assume_checker; From 2bca75128f67ef7c5921030e9068623b57f792cf Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Mon, 22 Jul 2024 00:38:11 -0500 Subject: [PATCH 6/9] Fixed lint error : trailing whitespaces in using_assume_to_reduce_breanches.cc --- src/tir/transforms/using_assume_to_reduce_branches.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index 996746bc99f2..8b3d08da082c 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -354,7 +354,7 @@ Pass UseAssumeToReduceBranches() { auto* n = f.CopyOnWrite(); arith::Analyzer analyzer; - // The pass runs & eliminates pad branch with overcompute only if, + // The pass runs & eliminates pad branch with overcompute only if, // the primfunc has op_pattern defined and is an elementwise op. // AnnotateTIROpPattern pass will set op_pattern in op attributes of the primfunc. if (n->attrs.GetAttr("op_pattern").defined()) { From 35346232f4babe0326c45f15d2bf969edccfba05 Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Mon, 22 Jul 2024 01:21:01 -0500 Subject: [PATCH 7/9] Fixed lint error: clang format issue in cpp files --- .../using_assume_to_reduce_branches.cc | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index 8b3d08da082c..4df602457bec 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -35,6 +35,7 @@ * 4. This pass currently works for op_pattern kElemWise and kBroadcast. */ +#include #include #include #include @@ -42,7 +43,7 @@ #include #include #include -#include + #include #include "../../arith/constraint_extract.h" @@ -78,7 +79,8 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { /* This class analyzes the complete primfunc. It parses the buffer assumptions and eliminates the redundant branch introduced due to layout specific padding by leveraging from buffer assumptions. - On eliminating the branch there are more opportunities to vectorize the code and improve performance. + On eliminating the branch there are more opportunities to vectorize the code + and improve performance. Example: ------------- @@ -86,7 +88,8 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { for (...) T.assume( assume_condition or A[i] == 0 ) for (...) - out = T.if_then_else(if_then_else_condition, 0, function(A)) # here function(A) is some function on Var A + out = T.if_then_else(if_then_else_condition, 0, function(A)) + # here function(A) is some function on Var A Prim Func After : for (...) @@ -96,12 +99,13 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { -------------- # High-level implementation details : 1. The pass parses the assume statement and stores the relevant information. - 2. The pass tries to evaluate the then_clause and else_clause in then_condition_context and else_condition_context. + 2. The pass tries to evaluate the then_clause and else_clause in then_condition_context + and else_condition_context. It checks if the context of the assume statement (for condition indices and - assume_condition) is same as the context of the if_then_else statement (for condition indices and - if_then_else condition). If context is same and the expression inside if_then_else statement is a function of the - buffer assumption (eg A in above example), then the pass substitutes the value from the buffer assumption and - simplifies the expression . + assume_condition) is same as the context of the if_then_else statement (for condition indices + and if_then_else condition). If context is same and the expression inside if_then_else statement + is a function of the buffer assumption (eg A in above example), + then the pass substitutes the value from the buffer assumption and simplifies the expression. 3. The pass then checks if then_clause and else_clause evaluate to same value. If yes, then return the else_clause if we are in the then_condition_context (since then_clause will be true in this context and if else_clause is also evaluating to true then we can directly From d7c7e22ea3a2ac58627c7cf80b62ff8e6a769df9 Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Mon, 22 Jul 2024 20:24:09 -0500 Subject: [PATCH 8/9] fixed pylint errors in python files and used clang format to format the cpp files --- include/tvm/tir/transform.h | 2 +- .../using_assume_to_reduce_branches.cc | 2 +- ...nate_pad_branch_using_buffer_assumption.py | 196 +++++++++--------- 3 files changed, 100 insertions(+), 100 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index fccdc566a693..a8d93bf898c4 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -835,7 +835,7 @@ TVM_DLL Pass InstrumentProfileIntrinsics(); TVM_DLL Pass DefaultGPUSchedule(); /*! - * \brief This pass analyzes primfunc and eliminates branch introdued due to layout specific padding. + * \brief This pass analyzes primfunc & eliminates branch introdued due to layout specific padding. * It leverages from the buffer assumptions and use the information to eliminate the branch. * \note This creates more opportunity to vectorize the code. * \return The Pass. diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index 4df602457bec..766a1d7783b0 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -88,7 +88,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { for (...) T.assume( assume_condition or A[i] == 0 ) for (...) - out = T.if_then_else(if_then_else_condition, 0, function(A)) + out = T.if_then_else(if_then_else_condition, 0, function(A)) # here function(A) is some function on Var A Prim Func After : diff --git a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py index 374503d215e7..7221644e8802 100644 --- a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py +++ b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py @@ -14,27 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=missing-docstring, unused-variable -# This test runs the reduce_pad_branch_through_over_compute test to check if we are able to eliminate the redundant pad branch and overcompute the value. +# The test attempts to eliminate redundant pad branch and overcompute the value for elementwise ops. # This helps to expose more opportunities to vectorize the code. import tvm import tvm.testing -from tvm import relax import tvm.script from tvm.script import tir as T, relax as R @tvm.script.ir_module -class Add_PrimFunc_Before: +class AddBefore: @T.prim_func(private=True) def add( - A: T.Buffer( + a: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), - B: T.Buffer( + b: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), @@ -59,7 +59,7 @@ def add( v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( not ( @@ -68,7 +68,7 @@ def add( or v_axis2 == T.int64(3) and T.int64(4) <= v_axis5 ) - or A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -79,7 +79,7 @@ def add( v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( not ( @@ -88,7 +88,7 @@ def add( or v_axis2 == T.int64(3) and T.int64(4) <= v_axis5 ) - or B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -100,8 +100,8 @@ def add( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) T.reads( - A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], - B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], ) T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) compute[ @@ -112,32 +112,32 @@ def add( or v_axis2 == T.int64(3) and T.int64(4) <= v_axis5, T.uint8(0), - A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] - + B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], ) @R.function def main( - A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), - B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): out = R.call_tir( - Add_PrimFunc_Before.add, - (A, B), + AddBefore.add, + (a, b), out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out @tvm.script.ir_module -class Add_PrimFunc_Expected: +class AddExpected: @T.prim_func(private=True) def add( - A: T.Buffer( + a: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), - B: T.Buffer( + b: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), @@ -163,12 +163,12 @@ def add( v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) - or A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -180,12 +180,12 @@ def add( v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) - or B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -203,39 +203,39 @@ def add( ) v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) T.reads( - A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], - B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], ) T.writes( compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] ) compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( - A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] - + B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] ) @R.function def main( - A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), - B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): out = R.call_tir( - Add_PrimFunc_Expected.add, - (A, B), + AddExpected.add, + (a, b), out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out @tvm.script.ir_module -class Sub_PrimFunc_Before: +class SubBefore: @T.prim_func(private=True) def sub( - A: T.Buffer( + a: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), - B: T.Buffer( + b: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), @@ -260,7 +260,7 @@ def sub( v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( not ( @@ -269,7 +269,7 @@ def sub( or v_axis2 == T.int64(3) and T.int64(4) <= v_axis5 ) - or A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -280,7 +280,7 @@ def sub( v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( not ( @@ -289,7 +289,7 @@ def sub( or v_axis2 == T.int64(3) and T.int64(4) <= v_axis5 ) - or B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -301,8 +301,8 @@ def sub( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) T.reads( - A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], - B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], ) T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) compute[ @@ -313,32 +313,32 @@ def sub( or v_axis2 == T.int64(3) and T.int64(4) <= v_axis5, T.uint8(0), - A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] - - B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + - b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], ) @R.function def main( - A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), - B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): out = R.call_tir( - Sub_PrimFunc_Before.sub, - (A, B), + SubBefore.sub, + (a, b), out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out @tvm.script.ir_module -class Sub_PrimFunc_Expected: +class SubExpected: @T.prim_func(private=True) def sub( - A: T.Buffer( + a: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), - B: T.Buffer( + b: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), @@ -364,12 +364,12 @@ def sub( v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) - or A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -381,12 +381,12 @@ def sub( v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) - or B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -404,39 +404,39 @@ def sub( ) v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) T.reads( - A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], - B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], ) T.writes( compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] ) compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( - A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] - - B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + - b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] ) @R.function def main( - A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), - B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): out = R.call_tir( - Sub_PrimFunc_Expected.sub, - (A, B), + SubExpected.sub, + (a, b), out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out @tvm.script.ir_module -class Mul_PrimFunc_Before: +class MulBefore: @T.prim_func(private=True) def mul( - A: T.Buffer( + a: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), - B: T.Buffer( + b: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), @@ -461,7 +461,7 @@ def mul( v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( not ( @@ -470,7 +470,7 @@ def mul( or v_axis2 == T.int64(3) and T.int64(4) <= v_axis5 ) - or A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -481,7 +481,7 @@ def mul( v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( not ( @@ -490,7 +490,7 @@ def mul( or v_axis2 == T.int64(3) and T.int64(4) <= v_axis5 ) - or B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -502,8 +502,8 @@ def mul( "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] ) T.reads( - A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], - B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], ) T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) compute[ @@ -514,32 +514,32 @@ def mul( or v_axis2 == T.int64(3) and T.int64(4) <= v_axis5, T.uint8(0), - A[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] - * B[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + * b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], ) @R.function def main( - A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), - B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): out = R.call_tir( - Mul_PrimFunc_Before.mul, - (A, B), + MulBefore.mul, + (a, b), out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out @tvm.script.ir_module -class Mul_PrimFunc_Expected: +class MulExpected: @T.prim_func(private=True) def mul( - A: T.Buffer( + a: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), - B: T.Buffer( + b: T.Buffer( (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), "uint8", ), @@ -565,12 +565,12 @@ def mul( v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) - or A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -582,12 +582,12 @@ def mul( v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] ) - T.reads(B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) T.writes() T.assume( (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) - or B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] == T.uint8(0) ) @@ -605,48 +605,48 @@ def mul( ) v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) T.reads( - A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], - B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], ) T.writes( compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] ) compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( - A[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] - * B[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + * b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] ) @R.function def main( - A: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), - B: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): out = R.call_tir( - Mul_PrimFunc_Expected.mul, - (A, B), + MulExpected.mul, + (a, b), out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), ) return out def test_add_primfunc_overcompute(): - Add_PrimFunc_After = tvm.tir.transform.UseAssumeToReduceBranches()(Add_PrimFunc_Before) + add_after = tvm.tir.transform.UseAssumeToReduceBranches()(AddBefore) tvm.ir.structural_equal( - Add_PrimFunc_After["add"], Add_PrimFunc_Expected["add"], map_free_vars=True + add_after["add"], AddExpected["add"], map_free_vars=True ) def test_sub_primfunc_overcompute(): - Sub_PrimFunc_After = tvm.tir.transform.UseAssumeToReduceBranches()(Sub_PrimFunc_Before) + sub_after = tvm.tir.transform.UseAssumeToReduceBranches()(SubBefore) tvm.ir.structural_equal( - Sub_PrimFunc_After["sub"], Sub_PrimFunc_Expected["sub"], map_free_vars=True + sub_after["sub"], SubExpected["sub"], map_free_vars=True ) def test_mul_primfunc_overcompute(): - Mul_PrimFunc_After = tvm.tir.transform.UseAssumeToReduceBranches()(Mul_PrimFunc_Before) + mul_after = tvm.tir.transform.UseAssumeToReduceBranches()(MulBefore) tvm.ir.structural_equal( - Mul_PrimFunc_After["mul"], Mul_PrimFunc_Expected["mul"], map_free_vars=True + mul_after["mul"], MulExpected["mul"], map_free_vars=True ) From 50b44f5d1a6375ebb5fba7729258c2507e148014 Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Mon, 22 Jul 2024 21:01:39 -0500 Subject: [PATCH 9/9] Ran black format and removed the attr_registry_map.h import as it was running into some other issue because of which build was failing --- .../transforms/using_assume_to_reduce_branches.cc | 1 - ...t_eliminate_pad_branch_using_buffer_assumption.py | 12 +++--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index 766a1d7783b0..2e45bb0ff8fb 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -35,7 +35,6 @@ * 4. This pass currently works for op_pattern kElemWise and kBroadcast. */ -#include #include #include #include diff --git a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py index 7221644e8802..b8ff2b6c79b2 100644 --- a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py +++ b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py @@ -631,23 +631,17 @@ def main( def test_add_primfunc_overcompute(): add_after = tvm.tir.transform.UseAssumeToReduceBranches()(AddBefore) - tvm.ir.structural_equal( - add_after["add"], AddExpected["add"], map_free_vars=True - ) + tvm.ir.structural_equal(add_after["add"], AddExpected["add"], map_free_vars=True) def test_sub_primfunc_overcompute(): sub_after = tvm.tir.transform.UseAssumeToReduceBranches()(SubBefore) - tvm.ir.structural_equal( - sub_after["sub"], SubExpected["sub"], map_free_vars=True - ) + tvm.ir.structural_equal(sub_after["sub"], SubExpected["sub"], map_free_vars=True) def test_mul_primfunc_overcompute(): mul_after = tvm.tir.transform.UseAssumeToReduceBranches()(MulBefore) - tvm.ir.structural_equal( - mul_after["mul"], MulExpected["mul"], map_free_vars=True - ) + tvm.ir.structural_equal(mul_after["mul"], MulExpected["mul"], map_free_vars=True) if __name__ == "__main__":