diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index f8371b1a6176..4cf6f086d1ed 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -285,6 +285,73 @@ class IterSumExpr : public IterMapExpr { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); + +/*! \brief A utility struct for return values from DetectPaddedIterMap + */ +struct PaddedIterMapResult { + // Any errors that occurred while converting the input indices. If + // the array is empty, the conversion was successful. + Array errors; + + // The detected pattern if a match exists. + Array indices; + + /*! \brief Boolean expression indicating if padding was required + * + * `requires_padding` evaluates to true if the returned indices + * contain padding relative to the provided expressions, and false + * otherwise. If `input_iters` contains a variable extent, this + * expression may be in terms of those variables. + */ + PrimExpr requires_padding; + + /*! \brief Boolean expression indicating if a specific value is padding + * + * `padding_predicate` evaluates to true for a set of indices that + * are outside the bounds of the provided index iterators, but + * inside the bounds of the returned index iterators. This + * expression is in terms of the variables provided in + * `input_iters`. + */ + PrimExpr padding_predicate; +}; + +/*! + * \brief Detect if indices can be written as + * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] + * + * Here y = some-quasi-affine-iter-map(input_iters) and c are + * symbolic constants. The y_i iterators may be padded to fit this + * representation. + * + * We also require y_i and y_j to be independent for i != j. + * + * For returned value rv, the following is always true: + * - rv.indices[i]->args.size() <=1: only one iterator per element. 
+ * + * \param indices The indices to detect pattern for. + * + * \param input_iters Map from variable to iterator's range. + * + * \param predicate The predicate constraints on the input iterators + * + * \param require_bijective A boolean flag that indicates whether the + * mapping should be bijective. If true, no padding may be + * introduced. + * + * \param analyzer Analyzer used to get context information. + * + * \param simplify_trivial_iterators If true, iterators with extent of + * 1 will be replaced with a constant value. + * + * \return An instance of PaddedIterMapResult. + */ +PaddedIterMapResult DetectPaddedIterMap(const Array& indices, + const Map& input_iters, + const PrimExpr& predicate, bool require_bijective, + arith::Analyzer* analyzer, + bool simplify_trivial_iterators = true); + /*! * \brief Use IterVarMap detector to rewrite and simplify the indices * @@ -352,11 +419,11 @@ Array> SubspaceDivide(const Array& bindings, bool require_bijective, arith::Analyzer* analyzer); /*! - * \brief Given an IterMapExpr, transform it to normal PrimExpr. - * \param expr The input IterMapExpr. + * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr. + * \param expr The input expression, which may contain IterMapExpr. * \return The corresponding normal PrimExpr. */ -PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr); +PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr); } // namespace arith } // namespace tvm diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 195bf7e02ce3..b6faa67ab53a 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -31,6 +31,8 @@ #include #include +#include + namespace tvm { namespace tir { @@ -141,12 +143,24 @@ class IndexMap : public ObjectRef { * * TODO(Lunderberg): Look into allowing non-bijective * transformations. If injective, the inverse mapping could still - * be generated with some predicate. 
If non-injective, could - * simplify the implementation of other optimizations (e.g. double - * buffering as a map `lambda *indices: [buffer_loop%2, *indices]`). + * be generated with some predicate (see NonSurjectiveInverse). If + * non-injective, could simplify the implementation of other + * optimizations (e.g. double buffering as a map `lambda *indices: + * [buffer_loop%2, *indices]`). */ IndexMap Inverse(Array initial_ranges) const; + /*! \brief Generate the inverse mapping. + * + * Determine the inverse, where the output range may contain + * addresses that do not correspond to an address in the input + * range. + * + * \return The inverted index map, along with the predicate for + * which the inverse maps to a valid range. + */ + std::pair NonSurjectiveInverse(Array initial_ranges) const; + TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode); }; diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 643bbca8eebd..d84513e072d3 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -16,13 +16,14 @@ # under the License. """Function data types.""" -from typing import Callable, List, Mapping, Optional, Union +from typing import Callable, List, Mapping, Optional, Union, Tuple import inspect +import tvm import tvm._ffi import tvm.runtime from tvm.runtime import Object -from tvm.ir import BaseFunc +from tvm.ir import BaseFunc, Range from .buffer import Buffer from .expr import Var, PrimExpr from . import _ffi_api @@ -296,12 +297,42 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): final_indices = mapping_function(*args) return IndexMap(args, final_indices) + def is_equivalent_to(self, other_map: "IndexMap") -> bool: + """Return if the index maps are equivalent. + + Parameters + ---------- + other_map: IndexMap + + The IndexMap to which the comparison should be made. 
+ + Returns + ------- + is_equivalent: bool + + True if the two mappings represent the same + transformation, otherwise False + """ + if len(self.initial_indices) != len(other_map.initial_indices): + return False + if len(self.final_indices) != len(other_map.final_indices): + return False + + analyzer = tvm.arith.Analyzer() + + mapped_other_final_indices = other_map.map_indices(self.initial_indices) + for self_index, other_index in zip(self.final_indices, mapped_other_final_indices): + if not analyzer.can_prove_equal(self_index, other_index): + return False + + return True + def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]: """Apply the index map to a set of indices Parameters ---------- - indices : List[PriExpr] + indices : List[PrimExpr] The indices to be mapped Returns @@ -310,3 +341,76 @@ def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]: The mapped indices """ return _ffi_api.IndexMapMapIndices(self, indices) + + def map_shape(self, shape: List[PrimExpr]) -> List[PrimExpr]: + """Apply the index map to a buffer shape + + Parameters + ---------- + shape : List[PrimExpr] + The buffer shape to be mapped + + Returns + ------- + result : List[PrimExpr] + The mapped shape + """ + return _ffi_api.IndexMapMapShape(self, shape) + + def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap": + """Return the inverse of the map + + Throws an error if the function is not bijective. + + Parameters + ---------- + shape: List[Union[Range,PrimExpr]] + + The region over which the inverse should be determined. + Used for validating that the mapping is bijective over + this range. 
+ + Returns + ------- + inverse : IndexMap + + The inverse + """ + + shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape] + return _ffi_api.IndexMapInverse(self, shape) + + def non_surjective_inverse( + self, shape: List[Union[Range, PrimExpr]] + ) -> Tuple["IndexMap", PrimExpr]: + """Return the inverse of the map + + Can be applied to transformations that introduce padding. + + Parameters + ---------- + shape: List[Union[Range,PrimExpr]] + + The region over which the inverse should be determined. + Used for determining the predicate. + + Returns + ------- + result : Tuple[IndexMap, PrimExpr] + + The inverse, and a predicate for which the inverse maps to + a valid index in the input range. + + Examples + -------- + + .. code-block:: python + + index_map = IndexMap.from_func(lambda i: [i//4, i%4]) + inverse_map, predicate = index_map.non_surjective_inverse([14]) + assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j,k: [4*j + k])) + print(predicate) # Prints "(axis0==3) && (axis1 >= 2)" + """ + + shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape] + return _ffi_api.IndexMapNonSurjectiveInverse(self, shape) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index ec2680d8e666..a012b6e80c08 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -26,6 +26,9 @@ #include #include #include +#include + +#include #include "../support/utils.h" #include "const_fold.h" @@ -174,8 +177,11 @@ class IterMapRewriter : public ExprMutator { using Parent = ExprMutator; explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, - bool simplify_trivial_iterators) - : analyzer_(analyzer) { + bool simplify_trivial_iterators, Array* errors) + : analyzer_(analyzer), + errors_(*errors), + requires_padding_(const_false()), + padding_predicate_(const_false()) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; @@ -195,12 +201,19 @@ class 
IterMapRewriter : public ExprMutator { } } - size_t unresolved_count() const { return unresolved_count_; } + PrimExpr padding_predicate() const { return padding_predicate_; } + PrimExpr requires_padding() const { return requires_padding_; } IterSumExpr Rewrite(const PrimExpr& expr) { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } + void UpdatePadding(const PrimExpr& expr) { + update_iterator_padding_ = true; + DirectMutate(expr); + update_iterator_padding_ = false; + } + IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const Optional& predicate_induced_min, const Optional& predicate_induced_max) { @@ -234,6 +247,7 @@ class IterMapRewriter : public ExprMutator { // All the splits that refers to the iter_mark covers its extent. // The splits do not overlap with each other. collector.Collect(bindings); + for (const IterMark& mark : collector.visited_) { if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) { return false; @@ -292,7 +306,9 @@ class IterMapRewriter : public ExprMutator { PrimExpr VisitExpr(const PrimExpr& input_expr) final { auto expr = ExprMutator::VisitExpr(input_expr); if (expr->IsInstance()) { - unresolved_count_++; + ErrorLogger(this) << "IterMapExpr or subclasses should only result from calls in " + << "IterMapRewriter using DirectMutate. " + << "Indirect return occurred in " << tvm::PrettyPrint(input_expr); } return expr; } @@ -308,6 +324,63 @@ class IterMapRewriter : public ExprMutator { PrimExpr VisitExpr_(const FloorModNode* op) final; private: + // Preprocessing common to both FloorDiv and FloorMod + IterSumExpr PreprocessDividend(IterMapExpr dividend); + + // Create an iterator that represents the expression (split+base), with + // padding such that the iterator's extents are evenly divisible by + // `divisor`. + // + // If iterators can have padding added through UpdatePadding, pad a + // dividend out to be evenly divisible. 
Otherwise, validate that the + // padding previously defined for the split using UpdatePadding can be + // used. If no such previous padding exists, return an undefined + // IterSplitExpr. + // + // Returns a pair of IterSplit that represents (split+base) in a + // form that can be divided by divisors, and PrimExpr that + // represents the left padding applied to split. + std::pair PadDividendToDivisor(IterSplitExpr split, PrimExpr base, + PrimExpr divisor); + + friend struct ErrorLogger; + + /* \brief Utility class for logging errors. + * + * It is not an error for IterMapRewriter to receive an expression that + * cannot be represented as an IterSumExpr. In these cases, + * IterMapRewriter returns the unrepresentable portions of the TIR graph + * without modification. As a result, the usual ICHECK or LOG(FATAL) + * macros cannot be used. Instead, ErrorLogger(this) can be used to + * report an unrepresentable TIR graph, which may be used in error + * messages at the calling scope. + */ + class ErrorLogger { + public: + explicit ErrorLogger(IterMapRewriter* rewriter) : rewriter(rewriter) {} + ~ErrorLogger() { rewriter->errors_.push_back(os.str()); } + + template + ErrorLogger& operator<<(T&& t) { + os << std::forward(t); + return *this; + } + + private: + IterMapRewriter* rewriter; + std::ostringstream os; + }; + + struct IterPaddingInfo { + // Used and collected during first pass + std::vector divisors; + + // Defined on first encounter in second pass + IterSplitExpr padded; + PrimExpr left_pad; + PrimExpr right_pad; + }; + // temp hash for de-duplication purposes. struct IterSumHash { size_t operator()(const IterSumExpr& value) const { @@ -344,12 +417,61 @@ class IterMapRewriter : public ExprMutator { // Internal analyzer Analyzer* analyzer_; - // Counter to keep track of unresolved cases. - int unresolved_count_{0}; + // Error messages for each unresolved expression. 
+ Array& errors_; // The var map std::unordered_map var_map_; // input iter marks std::vector input_marks_; + + // Map from a normal PrimExpr to the padded iterator information for + // it. This is necessary for introducing the same padding in all + // usage of an input iterator. (e.g. (i-1) occurring in the + // expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be + // left-padded by 31 for each occurrence.) + std::unordered_map padded_iter_map_; + + /* True only during UpdatePadding (the first pass), which records the + * divisors used to pad the extents beyond the original iterators. + * + * For example, once padding has been recorded, the expressions i//4 and + * i%4, where i is on the range [0,18), would be represented as + * IterSplit(i, lower_factor=4, extent=5) and IterSplit(i, extent=4). + * This representation is rejected if no padding was recorded, + * because lower_factor=4 does not evenly divide the original extent of + * 18. + */ + bool update_iterator_padding_{false}; + + /* A boolean expression that is true if any padding has been introduced + * by the transformation, and false otherwise. + * + * Example: [i//4, i%4], i in range [0,16) + * requires_padding_ will be false + * + * Example: [i//4, i%4], i in range [0,18) + * requires_padding_ will be true + * + * Example: [i//4, i%4], i in range [0,N) + * requires_padding_ will be the expression N%4!=0 + */ + PrimExpr requires_padding_; + + /* A boolean expression that is true for any padding that has been + * introduced, and false otherwise. If no padding is introduced, + * padding_predicate_ will always be false. 
+ * + * Example: [i//4, i%4], i in range [0,16) + * padding_predicate_ will be false + * + * Example: [i//4, i%4], i in range [0,18) + * padding_predicate_ will be `(i//4 == 4) && (i%4 >= 2)` + * + * Example: [i//4, i%4], i in range [0,N) + * padding_predicate_ will be `(N%4!=0) && (i//4 == (N+3)//4-1) && (i%4 >= N%4)` + */ + PrimExpr padding_predicate_; + // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly // an extra offset) // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) @@ -428,8 +550,9 @@ class IterMapRewriter : public ExprMutator { size_t j = 0; for (; j < splits.size(); ++j) { if (used[j]) continue; - if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) + if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) { break; + } } if (j == splits.size()) { // we do not allow incomplete split if the bindings should be bijective @@ -446,6 +569,7 @@ class IterMapRewriter : public ExprMutator { return Array(); } } + used[j] = true; iters.push_back(splits[j]); expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; @@ -456,9 +580,14 @@ class IterMapRewriter : public ExprMutator { // Case 2. bijective is not required. // We check the extent we calculate is a factor of the extent of the mark // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. 
- if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) || - (!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) { - return Array(); + if (require_bijective) { + if (!analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) { + return Array(); + } + } else { + if (!CanProveDivisible(mark->extent, expected_lower_factor)) { + return Array(); + } } return Array(iters.rbegin(), iters.rend()); } @@ -520,7 +649,7 @@ class IterMapRewriter : public ExprMutator { expr.CopyOnWrite()->base = base + iter_min; return expr; } - unresolved_count_++; + ErrorLogger(this) << "Could not normalize iterators using the constraints given."; return expr; } @@ -536,7 +665,7 @@ class IterMapRewriter : public ExprMutator { if (opt.defined()) { return opt.value(); } else { - unresolved_count_++; + ErrorLogger(this) << "Could not normalize iterators"; return expr; } } @@ -681,15 +810,10 @@ class IterMapRewriter : public ExprMutator { } } - bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) { - const auto* clhs = lhs.as(); - const auto* crhs = rhs.as(); - if (clhs && crhs) return clhs->value % crhs->value == 0; - return analyzer_->CanProveEqual(lhs, rhs) || analyzer_->CanProve(floormod(lhs, rhs) == 0); - } + bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs); - PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig); - PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig); + PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs); + PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs); static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) { tir::ExprDeepEqual equal; @@ -894,15 +1018,37 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer, 
bool simplify_trivial_iterators) { + auto padded_result = DetectPaddedIterMap(indices, input_iters, predicate, require_bijective, + analyzer, simplify_trivial_iterators); + if (padded_result.errors.size()) { + return Array(); + } + if (!analyzer->CanProve(!padded_result.requires_padding)) { + return Array(); + } + return padded_result.indices; +} + +PaddedIterMapResult DetectPaddedIterMap(const Array& indices, + const Map& input_iters, + const PrimExpr& predicate, bool require_bijective, + arith::Analyzer* analyzer, + bool simplify_trivial_iterators) { + PaddedIterMapResult result; + // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. - if (!IterRangeSanityCheck(input_iters)) return Array(); + if (!IterRangeSanityCheck(input_iters)) { + result.errors.push_back("Invalid iterators. Iterators may not be expressions of each other."); + return result; + } Map constrained_input_iters = input_iters; std::vector constraints; if (!is_one(predicate) && !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { - return Array(); + result.errors.push_back("Could not parse predicate as constraints on the input iterators."); + return result; } // We have to make sure when we visit an iterator, all the constraints related with its successors // in the iter var graph has been visited, where the expression of this iterator will contain the @@ -915,30 +1061,51 @@ Array DetectIterMap(const Array& indices, const Map(); + if (result.errors.size()) { + return result; + } } if (!rewriter.CheckConstraints()) { - return Array(); + result.errors.push_back("Invalid constraints."); + return result; } - // Step0.1: rewrite indices - Array results; + + // Step0.1: Check each index to determine required padding + bool allow_padding = !require_bijective; + if (allow_padding) { + for (PrimExpr value : indices) { + 
rewriter.UpdatePadding(value); + } + } + + // Step0.2: rewrite indices for (PrimExpr value : indices) { - results.push_back(rewriter.Rewrite(value)); - if (rewriter.unresolved_count() != 0) { - return Array(); + result.indices.push_back(rewriter.Rewrite(value)); + if (result.errors.size()) { + return result; } } + + result.requires_padding = rewriter.requires_padding(); + result.padding_predicate = rewriter.padding_predicate(); + // Step1: IterIndependenceChecker checks if the iterator are independent. - if (!rewriter.CheckMapping(results, require_bijective)) { - return Array(); + if (!rewriter.CheckMapping(result.indices, require_bijective)) { + if (require_bijective) { + result.errors.push_back("Index mapping does not form a bijective transform."); + } else { + result.errors.push_back("Mapped indices are not independent."); + } + return result; } - return results; + return result; } TVM_REGISTER_GLOBAL("arith.DetectIterMap") @@ -1050,7 +1217,8 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (a->IsInstance() && b->IsInstance()) { // cannot multiply two iterators, mark as unresolved. 
- unresolved_count_++; + ErrorLogger(this) << "Product of two iterators cannot be represented as an IterMap, " + << "occurs in " << tvm::PrettyPrint(GetRef(op)); return GetRef(op); } @@ -1070,46 +1238,232 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { } } -PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, - const PrimExpr& orig) { - // floordiv(x*scale, rhs) - if (is_one(rhs)) return std::move(lhs); +IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend) { + if (dividend->IsInstance()) { + auto split = Downcast(dividend); + return IterSumExpr({split}, make_zero(split.dtype())); + } else if (dividend->IsInstance()) { + auto opt_fused = TryFuseIters(Downcast(dividend)); + if (!opt_fused) { + ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(dividend) + << ", can't be written as a single fused IterSum"; + return IterSumExpr(); + } + + IterSumExpr fused = opt_fused.value(); + + ICHECK_EQ(fused->args.size(), 1U); + return fused; + } else { + LOG(FATAL) << "Unsupported subclass of IterMapExpr"; // only a single split or a fusable sum is handled above + return IterSumExpr(); + } +} + +std::pair IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, + PrimExpr base, + PrimExpr divisor) { + // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor + // If FloorMod: (((source//lower_factor) % extent) + base) % divisor + + PrimExpr lookup_key = split; + + auto modified_divisor = [&]() { + if (update_iterator_padding_) { + return divisor; + } + + auto it = padded_iter_map_.find(lookup_key); + if (it == padded_iter_map_.end()) { + return divisor; + } + + const std::vector& divisors = it->second.divisors; + PrimExpr largest_divisor = divisor; + for (const auto& other : divisors) { + if (CanProveDivisible(other, largest_divisor)) { + // New one is bigger, use it + largest_divisor = other; + } else if (CanProveDivisible(largest_divisor, other)) { + // Current is bigger, keep it + } else { + ErrorLogger(this) << "Iterator appears in multiple terms with 
incompatible divisors " + << tvm::PrettyPrint(largest_divisor) << " and " + << tvm::PrettyPrint(other); + } + } + return largest_divisor; + }(); + + divisor = modified_divisor; + + // First, adding any padding that is on the lower side of a + // FloorDiv/FloorMod, such that floormod(iter-left_pad,divisor) == 0 + // when iter==0. + + PrimExpr left_pad; + + if (is_zero(base)) { + // Padding on the left is unnecessary if base is known to be zero. + left_pad = make_zero(base->dtype); + } else { + left_pad = analyzer_->Simplify(floormod(base, divisor)); + } + + // Next, adding any padding that is on the upper side of a + // FloorDiv/FloorMod, such that floormod(left_pad + iter + right_pad, divisor) == 0 + // when iter==extent. + + PrimExpr right_edge = left_pad + split->extent; + PrimExpr right_pad; + + if (CanProveDivisible(right_edge, divisor)) { + // Padding on the right is unnecessary if the extent is a multiple of + // the divisor. + right_pad = 0; + } else { + right_pad = analyzer_->Simplify(floormod(-right_edge, divisor)); + } + + if (is_zero(left_pad) && is_zero(right_pad)) { + return {split, left_pad}; + } + + if (update_iterator_padding_) { + // In the first pass, the primary goal is to collect all the divisors + // that may be used for padding. These will impact the divisor used + // to determine padding in the second pass. + IterPaddingInfo& info = padded_iter_map_[lookup_key]; + + info.divisors.push_back(divisor); + + PrimExpr padded_extent = left_pad + split->extent + right_pad; + + IterSumExpr as_sum({split}, left_pad); + IterMark mark(as_sum, padded_extent); + IterSplitExpr new_split(mark); + + return {new_split, left_pad}; + } + + // Any padding that is required during parsing should have been found + // during the first pass that determines the GCD. 
+ auto it = padded_iter_map_.find(lookup_key); + if (it == padded_iter_map_.end()) { + ErrorLogger(this) << "Dividend has extent " << tvm::PrettyPrint(split->extent) << " and offset " + << tvm::PrettyPrint(base) << ", which requires padding for divisor " + << tvm::PrettyPrint(divisor) << "."; + return {IterSplitExpr(), left_pad}; + } + IterPaddingInfo& info = it->second; + + if (info.padded.defined()) { + // A previous visit already applied padding to this iterator. + // (e.g. Visiting `(i+1)//4`, then visiting `(i+1)%4`). + ICHECK(analyzer_->CanProveEqual(info.left_pad, left_pad)); + ICHECK(analyzer_->CanProveEqual(info.right_pad, right_pad)); + + return {info.padded, left_pad}; + } + + // This is the first encounter with the iterator during the second pass. + IterSumExpr as_sum({split}, left_pad); + IterMark mark(as_sum, left_pad + split->extent + right_pad); + info.padded = IterSplitExpr(mark); + info.left_pad = left_pad; + info.right_pad = right_pad; + + auto left_padding_introduced = (left_pad != 0); + // Equivalent to (0 <= split < left_pad), but easier to simplify in + // terms of the transformed variables. + auto left_padding_predicate = + left_padding_introduced && (floordiv(info.padded, divisor) == floordiv(base, divisor) && + floormod(info.padded, divisor) < left_pad); + + PrimExpr nparts = ceildiv(right_edge, divisor); + + auto right_padding_introduced = (right_pad != 0); + + // Equivalent to (right_edge <= split < right_edge+right_pad), but + // easier to simplify in terms of the transformed variables. 
+ auto right_padding_predicate = right_padding_introduced && + (floordiv(info.padded, divisor) == floordiv(right_edge, divisor) && + floormod(info.padded, divisor) >= floormod(right_edge, divisor)); + + requires_padding_ = requires_padding_ || (left_padding_introduced || right_padding_introduced); + padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); + + return {info.padded, left_pad}; +} + +PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { + // (lhs + base) // rhs + + if (is_one(rhs)) { + if (is_zero(base)) { + // floordiv(x, 1) = x + return std::move(lhs); + } else { + // floordiv(x+y, 1) = x+y + return IterSumExpr({lhs}, base); + } + } + if (!is_one(lhs->scale)) { - if (CanProveDivisible(lhs->scale, rhs)) { + if (CanProveDivisible(lhs->scale, rhs) && is_zero(base)) { // floordiv(x*c1*c2, c2) = x*c1, c1=scale/rhs lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs); return std::move(lhs); + } else if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) { + // floordiv(x*c1*c2 + y*c2, c2) = x*c1 + y, c1=scale/rhs + lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs); + return IterSumExpr({lhs}, floordiv(base, rhs)); + } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { + // floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale + rhs = floordiv(rhs, lhs->scale); + lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); + } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) { + // floordiv(x*c1 + y*c1, c1*c2) = floordiv(x+y, c2), c2=rhs/scale + base = floordiv(base, lhs->scale); + rhs = floordiv(rhs, lhs->scale); + lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); } else { - if (CanProveDivisible(rhs, lhs->scale)) { - // floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale - rhs = floordiv(rhs, lhs->scale); - lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); - } else { - // mark as unresolved. 
- unresolved_count_++; - return orig; - } + // mark as unresolved. + ErrorLogger(this) << "Cannot represent as IterMap: the numerator's scaling factor, " + << tvm::PrettyPrint(lhs->scale) << " and the divisor " + << tvm::PrettyPrint(rhs) + << " cannot be simplified to remove the scaling factor."; + return PrimExpr(); } } // We handle scale!=1 in above code, hence we only consider floordiv(x, rhs) below - // where x=floormod(floordiv(iter, lower_factor), extent) - if (CanProveDivisible(lhs->extent, rhs)) { - // floordiv(floormod(floordiv(iter, lower_factor), c1c2), c1) - // = floordiv(floormod(y, c1c2), c1), where y=floordiv(iter, lower_factor) - // = floordiv(floormod(sc1c2+tc1+u, c1c2), c1), where y=sc1c2+tc1+u, tlower_factor *= rhs; - ptr_lhs->extent = analyzer_->Simplify(floordiv(ptr_lhs->extent, rhs)); - return std::move(lhs); + // where x=floormod(floordiv(iter, lower_factor), extent) + base + + auto pair = PadDividendToDivisor(lhs, base, rhs); + IterSplitExpr padded = pair.first; + PrimExpr left_pad = pair.second; + if (!padded.defined()) { + return PrimExpr(); + } + + // floordiv(floormod(floordiv(iter, lower_factor), c1c2), c1) + // = floordiv(floormod(y, c1c2), c1), where y=floordiv(iter, lower_factor) + // = floordiv(floormod(sc1c2+tc1+u, c1c2), c1), where y=sc1c2+tc1+u, tsource, + /* lower_factor = */ padded->lower_factor * rhs, + /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), + /* scale = */ padded->scale); + + auto new_base = floordiv(base - left_pad, rhs); + if (is_zero(new_base)) { + return std::move(new_split); } else { - // mark as unresolved. - unresolved_count_++; - return orig; + return IterSumExpr({new_split}, new_base); } } @@ -1136,62 +1490,66 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (b->IsInstance()) { // cannot divide an iterator, mark as unresolved. 
- unresolved_count_++; + ErrorLogger(this) << "Cannot represent as an IterMap: the divisor in " << GetRef(op) + << " may not be an iterator"; return GetRef(op); } - if (a->IsInstance()) { - IterSumExpr ret = Downcast(a); - if (Optional opt = TryFuseIters(ret)) { - IterSumExpr sum = opt.value(); - if (!is_zero(sum->base)) { - unresolved_count_++; - return GetRef(op); - } - ICHECK_EQ(sum->args.size(), 1U); - return SplitFloorDivConst(sum->args[0], b, GetRef(op)); - } else { - unresolved_count_++; - return GetRef(op); - } - } else { - ICHECK(a->IsInstance()); - IterSplitExpr ret = Downcast(std::move(a)); - return SplitFloorDivConst(ret, b, GetRef(op)); + IterSumExpr preprocessed = PreprocessDividend(Downcast(a)); + if (!preprocessed.defined()) { + return GetRef(op); + } + PrimExpr remainder = SplitFloorDivConst(preprocessed->args[0], preprocessed->base, b); + if (!remainder.defined()) { + return GetRef(op); } + return remainder; } -PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, - const PrimExpr& orig) { - // floormod(x*scale, rhs) - if (is_one(rhs)) return make_zero(lhs->dtype); +PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { + // (lhs + base) % rhs + + if (is_one(rhs)) { + // floormod(x, 1) = 0 + return make_zero(lhs->dtype); + } + if (!is_one(lhs->scale)) { - // floormod(x*c1*c2, c1) = 0 - if (CanProveDivisible(lhs->scale, rhs)) { + if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) { + // floormod(x*c1*c2, c1) = 0 return make_zero(lhs->dtype); + } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { + // floormod(x*c1, c1*c2) = (floormod(x, c2)) * c1, where c2 = rhs/scale + rhs = floordiv(rhs, lhs->scale); + } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) { + // floormod(x*c1 + y*c1, c1*c2) = (floormod(x+y, c2)) * c1, where c2 = rhs/scale + rhs = floordiv(rhs, lhs->scale); + base = floordiv(base, lhs->scale); } else { 
- if (CanProveDivisible(rhs, lhs->scale)) { - // floormod(x*c1, c1*c2) = (floormod(x, c2)) * c1, where c2 = rhs/scale - rhs = floordiv(rhs, lhs->scale); - } else { - // mark as unresolved. - unresolved_count_++; - return orig; - } + // mark as unresolved. + ErrorLogger(this) + << "Cannot represent as IterMap: the left-hand side of FloorMod has a scaling factor, " + << tvm::PrettyPrint(lhs->scale) << " and the right-hand side " << tvm::PrettyPrint(rhs) + << " cannot be used to simplify out the scaling factor."; + return PrimExpr(); } } - // floormod(x, rhs) where x=floormod(floordiv(iter, lower_factor), extent) - if (CanProveDivisible(lhs->extent, rhs)) { - // floormod(floormod(floordiv(iter, lower_factor), c1c2), c1) - // = floormod(floordiv(iter, lower_factor), c1), where c1=rhs - lhs.CopyOnWrite()->extent = rhs; - return std::move(lhs); - } else { - // mark as unresolved. - unresolved_count_++; - return orig; + // We handle scale!=1 in above code, hence we only consider floormod(x, rhs) below + // where x=floormod(floordiv(iter, lower_factor), extent) + base + + auto pair = PadDividendToDivisor(lhs, base, rhs); + IterSplitExpr padded = pair.first; + if (!padded.defined()) { + return PrimExpr(); } + + // floormod(floormod(floordiv(iter, lower_factor), c1c2), c1) + // = floormod(floordiv(iter, lower_factor), c1), where c1=rhs + return IterSplitExpr(padded->source, + /* lower_factor = */ padded->lower_factor, + /* extent = */ rhs, + /* scale = */ padded->scale); } PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { @@ -1217,36 +1575,30 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (b->IsInstance()) { // cannot mod an iterator, mark as unresolved. 
- unresolved_count_++; + ErrorLogger(this) << "Cannot represent as an IterMap: the right-hand side of FloorMod in " + << GetRef(op) << " may not be an iterator"; return GetRef(op); } - if (a->IsInstance()) { - IterSumExpr ret = Downcast(a); - if (Optional opt = TryFuseIters(ret)) { - IterSumExpr sum = opt.value(); - if (!is_zero(sum->base)) { - unresolved_count_++; - return GetRef(op); - } - return SplitFloorModConst(sum->args[0], b, GetRef(op)); - } else { - unresolved_count_++; - return GetRef(op); - } - } else { - ICHECK(a->IsInstance()); - IterSplitExpr ret = Downcast(std::move(a)); - return SplitFloorModConst(ret, b, GetRef(op)); + IterSumExpr preprocessed = PreprocessDividend(Downcast(a)); + if (!preprocessed.defined()) { + return GetRef(op); } + + PrimExpr remainder = SplitFloorModConst(preprocessed->args[0], preprocessed->base, b); + if (!remainder.defined()) { + return GetRef(op); + } + return remainder; } -/*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */ +/*! * \brief Given an expression that may contain IterVarMapExpr, transform it to normal PrimExpr. + */ class IterMapToExprNormalizer : public ExprMutator { public: explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {} - PrimExpr Convert(const IterMapExpr& expr) { return VisitExpr(expr); } + PrimExpr Convert(const PrimExpr& expr) { return VisitExpr(expr); } private: /*! 
\brief Override VisitExpr for iter expr type processing */ @@ -1292,15 +1644,32 @@ class IterMapToExprNormalizer : public ExprMutator { Analyzer* analyzer_; }; -PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr) { +bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) { + const auto* clhs = lhs.as(); + const auto* crhs = rhs.as(); + if (crhs && crhs->value == 0) { + // A zero divisor divides nothing; returning early also avoids a modulo-by-zero below. + return false; + } + if (clhs && crhs) { + return clhs->value % crhs->value == 0; + } + + IterMapToExprNormalizer normalizer(analyzer_); + PrimExpr dividend = normalizer.Convert(lhs); + PrimExpr divisor = normalizer.Convert(rhs); + + return analyzer_->CanProveEqual(dividend, divisor) || + analyzer_->CanProve(floormod(dividend, divisor) == 0); +} + +PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { arith::Analyzer analyzer; IterMapToExprNormalizer normalizer(&analyzer); return normalizer.Convert(expr); } -TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const IterMapExpr& expr) { - return NormalizeIterMapToExpr(expr); -}); +TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); Array IterMapSimplify(const Array& indices, const Map& input_iters, const PrimExpr& input_pred, bool require_bijective) { diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 0f7aa4c8a978..4d8b6ff769cf 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -448,6 +448,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { TVM_TRY_REWRITE(min(x, y) * max(x, y), x * y); TVM_TRY_REWRITE(max(x, y) * min(x, y), x * y); + // Two representations of const*ceildiv(x, c1) + TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c2), c1) * c1, x - floormod(x, c2), + c1.Eval()->value == -c2.Eval()->value); + // canonicalization TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1); TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1); diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 93f308b42d74..4c0a7d3508c1 100644 --- 
a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -50,6 +50,70 @@ IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc(A return IndexMap(initial_indices, func(initial_indices)); } +std::pair IndexMap::NonSurjectiveInverse(Array initial_ranges) const { + // Dummy variables to represent the inverse's inputs. + Array output_vars; + for (size_t i = 0; i < (*this)->final_indices.size(); i++) { + PrimExpr index = (*this)->final_indices[i]; + // TODO(Lunderberg): Better names for these variables. A variable + // that is passed through unmodified (`index` is an element of + // `initial_indices`) should use that input index's name. A pair + // of output indices variables split from a single input index + // should be named (X.outer,X.inner). + std::stringstream ss; + ss << "axis" << i; + Var var_index(ss.str(), index.dtype()); + output_vars.push_back(var_index); + } + + // Dummy ranges for the extent of each input. + Map input_iters; + ICHECK_EQ((*this)->initial_indices.size(), initial_ranges.size()); + for (size_t i = 0; i < initial_ranges.size(); i++) { + input_iters.Set((*this)->initial_indices[i], initial_ranges[i]); + } + + // Unpack the output indices into linear combinations of the initial + // indices. + arith::Analyzer analyzer; + auto padded_iter_map = + DetectPaddedIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, + /* require_bijective = */ false, &analyzer, + /* simplify_trivial_iterators = */ false); + CHECK(padded_iter_map.errors.empty()) << "Could not parse mapping as sum of iterators. " + << "Error: " << padded_iter_map.errors[0]; + + // Determine expressions for the input variables, in terms of the + // output variables. + Map inverse_exprs_map = InverseAffineIterMap( + padded_iter_map.indices, Array(output_vars.begin(), output_vars.end())); + + // Unpack the map to an array, maintaining the same parameter order. 
+ Array inverse_exprs; + for (const auto& index : (*this)->initial_indices) { + inverse_exprs.push_back(inverse_exprs_map.at(index)); + } + + PrimExpr padding_predicate = padded_iter_map.padding_predicate; + padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate); + padding_predicate = Substitute(padding_predicate, inverse_exprs_map); + + { + auto output_ranges = (*this)->MapRanges(initial_ranges); + ICHECK_EQ(output_ranges.size(), output_vars.size()); + + arith::Analyzer analyzer; + for (size_t i = 0; i < output_vars.size(); ++i) { + analyzer.Bind(output_vars[i], output_ranges[i]); + } + + // Additional simplification steps required to unwrap nested floordiv/floormod + padding_predicate = analyzer.Simplify(padding_predicate, 10); + } + + return {IndexMap(output_vars, inverse_exprs), padding_predicate}; +} + IndexMap IndexMap::Inverse(Array initial_ranges) const { // Dummy variables to represent the inverse's inputs. Array output_vars; @@ -202,6 +266,14 @@ TVM_REGISTER_GLOBAL("tir.IndexMap") }); TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices").set_body_method(&IndexMapNode::MapIndices); +TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_method(&IndexMapNode::MapShape); +TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse); + +TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse") + .set_body_typed([](IndexMap forward, Array initial_ranges) { + auto result = forward.NonSurjectiveInverse(initial_ranges); + return Array{result.first, result.second}; + }); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index eb43157d805a..3c11d2485332 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -109,6 +109,7 @@ bool ProducerCoversConsumer(const Array& buffer_shape, analyzer->canonical_simplify(consumed_region[i].max())); produced = arith::Intersect({produced, buffer_size}); consumed = arith::Intersect({consumed, buffer_size}); + if 
(!analyzer->CanProve((analyzer->canonical_simplify(produced.min() - consumed.min()) <= 0) && (analyzer->canonical_simplify(consumed.max() - produced.max()) <= 0))) { return false; diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_index_map.py new file mode 100644 index 000000000000..a8f5204f0202 --- /dev/null +++ b/tests/python/unittest/test_index_map.py @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest + +import tvm +from tvm.tir import IndexMap +from tvm.ir import assert_structural_equal + + +def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: + + iters_1 = map1.map_indices(map2.initial_indices) + iters_2 = map2.final_indices + assert len(iters_1) == len(iters_2) + + analyzer = tvm.arith.Analyzer() + for iter1, iter2 in zip(iters_1, iters_2): + assert analyzer.can_prove_equal(iter1, iter2) + + +def test_index_mapping(): + index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + + assert_structural_equal(index_map.map_indices([0]), [0, 0]) + assert_structural_equal(index_map.map_indices([3]), [0, 3]) + assert_structural_equal(index_map.map_indices([4]), [1, 0]) + assert_structural_equal(index_map.map_indices([42]), [10, 2]) + + +def test_shape_mapping(): + index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + + assert_structural_equal(index_map.map_shape([4]), [1, 4]) + assert_structural_equal(index_map.map_shape([16]), [4, 4]) + + assert_structural_equal(index_map.map_shape([14]), [4, 4]) + + +def test_inverse(): + index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + expected_inverse = IndexMap.from_func(lambda i, j: [4 * i + j]) + + assert index_map.inverse([16]).is_equivalent_to(expected_inverse) + + +def test_nonbijective_inverse_gives_error(): + index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + + with pytest.raises(tvm.TVMError): + index_map.inverse([14]) + + +dynamic_N = tvm.tir.Var("N", "int32") +padding_test_case = tvm.testing.parameter( + by_dict={ + "no_padding": dict( + forward=lambda i: [i // 4, i % 4], + inverse=lambda i, j: [4 * i + j], + pre_shape=[16], + post_shape=[4, 4], + padding=lambda i, j: tvm.runtime.convert(False), + ), + "right_padding": dict( + forward=lambda i: [i // 4, i % 4], + inverse=lambda i, j: [4 * i + j], + pre_shape=[15], + post_shape=[4, 4], + padding=lambda i, j: tvm.tir.And(i == 3, j >= 3), + ), + "left_padding": dict( + forward=lambda i: [(i + 1) // 4, (i + 1) % 4], 
+ inverse=lambda i, j: [4 * i + j - 1], + pre_shape=[15], + post_shape=[4, 4], + padding=lambda i, j: tvm.tir.And(i == 0, j < 1), + ), + "left_and_right_padding": dict( + forward=lambda i: [(i + 1) // 4, (i + 1) % 4], + inverse=lambda i, j: [4 * i + j - 1], + pre_shape=[14], + post_shape=[4, 4], + padding=lambda i, j: tvm.tir.Or( + tvm.tir.And(i == 0, j < 1), + tvm.tir.And(i == 3, j >= 3), + ), + ), + "dynamic_size": dict( + forward=lambda i: [i // 4, i % 4], + inverse=lambda i, j: [4 * i + j], + pre_shape=[dynamic_N], + post_shape=[(dynamic_N - 1) // 4 + 1, 4], + padding=lambda i, j: tvm.tir.And( + dynamic_N % (-4) != 0, + tvm.tir.And(i == dynamic_N // 4, j >= dynamic_N % 4), + ), + ), + "2d_padding": dict( + forward=lambda i, j: [(i + 1) // 4, (j + 5) // 8, (i + 1) % 4, (j + 5) % 8], + inverse=lambda i_outer, j_outer, i_inner, j_inner: [ + 4 * i_outer + i_inner - 1, + 8 * j_outer + j_inner - 5, + ], + pre_shape=[14, 31], + post_shape=[ + 4, # ceildiv(left_pad + i.extent, 4) = ceildiv(1 + 14, 4) = 4 + 5, # ceildiv(left_pad + j.extent, 8) = ceildiv(5 + 31, 8) = 5 + 4, # Range of iter%4 + 8, # Range of iter%8 + ], + padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tir.Or( + tvm.tir.Or( + tvm.tir.And(i_outer == 0, i_inner < 1), + tvm.tir.And(i_outer == 3, i_inner >= 3), + ), + tvm.tir.Or( + tvm.tir.And(j_outer == 0, j_inner < 5), + tvm.tir.And(j_outer == 4, j_inner >= 4), + ), + ), + ), + "multiple_right_padding": dict( + forward=lambda i: [i // 32, (i // 4) % 8, i % 4], + inverse=lambda i, j, k: [32 * i + 4 * j + k], + pre_shape=[116], + post_shape=[4, 8, 4], + padding=lambda i, j, k: tvm.tir.And(i == 3, 4 * j + k >= 20), + ), + "multiple_right_padding_transpose": dict( + forward=lambda i: [(i // 4) % 8, i // 32, i % 4], + inverse=lambda j, i, k: [32 * i + 4 * j + k], + pre_shape=[116], + post_shape=[8, 4, 4], + padding=lambda j, i, k: tvm.tir.And(i == 3, 4 * j + k >= 20), + ), + "multiple_left_padding": dict( + forward=lambda i: [(i + 5) // 32, ((i + 5) // 4) 
% 8, (i + 5) % 4], + inverse=lambda i, j, k: [32 * i + 4 * j + k - 5], + pre_shape=[123], + post_shape=[4, 8, 4], + padding=lambda i, j, k: tvm.tir.And(i == 0, j * 4 + k < 5), + ), + "multiple_left_padding_with_transpose": dict( + forward=lambda i: [((i + 5) // 4) % 8, (i + 5) // 32, (i + 5) % 4], + inverse=lambda j, i, k: [32 * i + 4 * j + k - 5], + pre_shape=[123], + post_shape=[8, 4, 4], + padding=lambda j, i, k: tvm.tir.And(i == 0, j * 4 + k < 5), + ), + } +) + + +def test_nonsurjective_inverse(padding_test_case): + index_map = IndexMap.from_func(padding_test_case["forward"]) + + inverse, padding_predicate = index_map.non_surjective_inverse(padding_test_case["pre_shape"]) + expected_inverse = IndexMap.from_func(padding_test_case["inverse"]) + assert inverse.is_equivalent_to(expected_inverse) + + post_shape = index_map.map_shape(padding_test_case["pre_shape"]) + tvm.ir.assert_structural_equal(post_shape, padding_test_case["post_shape"]) + + expected_predicate = padding_test_case["padding"](*inverse.initial_indices) + + # Can't use analyzer.can_prove_equal, because it can't simplify + # expressions like `(4*i+j >= 14) - (4*i+j >= 14)`. 
+ analyzer = tvm.arith.Analyzer() + expected_predicate = analyzer.simplify(expected_predicate) + padding_predicate = analyzer.simplify(padding_predicate) + tvm.ir.assert_structural_equal(padding_predicate, expected_predicate) + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 10371d3ccaf1..19be0b8699ac 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -48,14 +48,6 @@ def _make_loops(loop_vars: List[Var], extents: List[int]) -> List[For]: ] -def _assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: - iters_1 = map1.map_indices(map2.initial_indices) - iters_2 = map2.final_indices - assert len(iters_1) == len(iters_2) - for iter1, iter2 in zip(iters_1, iters_2): - assert expr_deep_equal(iter1, iter2) - - def test_suggest_index_map_simple(): i, j = _make_vars("i", "j") index_map = suggest_index_map( @@ -78,7 +70,7 @@ def test_suggest_index_map_simple(): floormod(y, 16), ], ) - _assert_equal_index_map(index_map, expected_index_map) + assert index_map.is_equivalent_to(expected_index_map) def test_suggest_index_map_bijective(): @@ -98,7 +90,7 @@ def test_suggest_index_map_bijective(): floordiv(x, 2), ], ) - _assert_equal_index_map(index_map, expected_index_map) + assert index_map.is_equivalent_to(expected_index_map) @tvm.script.ir_module