From b9fc65d365a7d99ad17f7ac33a8f34685f61d4fa Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 28 Apr 2022 14:03:37 -0500 Subject: [PATCH 01/10] [Debug] Error logging in DetectIterMap --- src/arith/iter_affine_map.cc | 126 +++++++++++++++++++++++++---------- 1 file changed, 92 insertions(+), 34 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index ec2680d8e666..9a05894398a6 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -195,7 +195,15 @@ class IterMapRewriter : public ExprMutator { } } - size_t unresolved_count() const { return unresolved_count_; } + size_t unresolved_count() const { return errors_.size(); } + + void print_errors() const { + for (const auto& err : errors_) { + std::cout << "Error: " << err << std::endl; + } + } + + std::vector errors() const { return errors_; } IterSumExpr Rewrite(const PrimExpr& expr) { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); @@ -292,7 +300,9 @@ class IterMapRewriter : public ExprMutator { PrimExpr VisitExpr(const PrimExpr& input_expr) final { auto expr = ExprMutator::VisitExpr(input_expr); if (expr->IsInstance()) { - unresolved_count_++; + ErrorLogger(this) << "IterMapExpr or subclasses should only result from calls in " + << "IterMapRewriter using DirectMutate. " + << "Indirect return occurred in " << tvm::PrettyPrint(input_expr); } return expr; } @@ -308,6 +318,33 @@ class IterMapRewriter : public ExprMutator { PrimExpr VisitExpr_(const FloorModNode* op) final; private: + /* \brief Utility class for logging errors. + * + * It is not an error for IterMapRewriter to receive an expression that + * cannot be represented as an IterSumExpr. In these cases, + * IterMapRewriter returns the unrepresentable portions of the TIR graph + * without modification. As a result, the usual ICHECK or LOG(FATAL) + * macros cannot be used. Instead, ErrorLogger(this) can be used to + * report an unrepresentable TIR graph, which may be used in error + * messages at the calling scope. + */ + friend struct ErrorLogger; + class ErrorLogger { + public: + explicit ErrorLogger(IterMapRewriter* rewriter) : rewriter(rewriter) {} + ~ErrorLogger() { rewriter->errors_.push_back(os.str()); } + + template + ErrorLogger& operator<<(T&& t) { + os << std::forward(t); + return *this; + } + + private: + IterMapRewriter* rewriter; + std::ostringstream os; + }; + // temp hash for de-duplication purposes. struct IterSumHash { size_t operator()(const IterSumExpr& value) const { @@ -344,8 +381,8 @@ class IterMapRewriter : public ExprMutator { // Internal analyzer Analyzer* analyzer_; - // Counter to keep track of unresolved cases. - int unresolved_count_{0}; + // Error messages for each unresolved expression. + std::vector errors_; // The var map std::unordered_map var_map_; // input iter marks @@ -520,7 +557,7 @@ class IterMapRewriter : public ExprMutator { expr.CopyOnWrite()->base = base + iter_min; return expr; } - unresolved_count_++; + ErrorLogger(this) << "Could not normalize iterators using the constraints given."; return expr; } @@ -536,7 +573,7 @@ class IterMapRewriter : public ExprMutator { if (opt.defined()) { return opt.value(); } else { - unresolved_count_++; + ErrorLogger(this) << "Could not normalize iterators"; return expr; } } @@ -897,7 +934,9 @@ Array DetectIterMap(const Array& indices, const Map(); + if (!IterRangeSanityCheck(input_iters)) { + return Array(); + } Map constrained_input_iters = input_iters; std::vector constraints; if (!is_one(predicate) && @@ -920,11 +959,14 @@ Array DetectIterMap(const Array& indices, const Map(); + if (rewriter.unresolved_count() != 0) { + return Array(); + } } if (!rewriter.CheckConstraints()) { return Array(); } + // Step0.1: rewrite indices Array results; for (PrimExpr value : indices) { @@ -1050,7 +1092,8 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (a->IsInstance() && b->IsInstance()) { // cannot multiply two iterators, mark as unresolved. - unresolved_count_++; + ErrorLogger(this) << "Product of two iterators cannot be represented as an IterMap, " + << "occurs in " << tvm::PrettyPrint(GetRef(op)); return GetRef(op); } @@ -1079,16 +1122,17 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, // floordiv(x*c1*c2, c2) = x*c1, c1=scale/rhs lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs); return std::move(lhs); + } else if (CanProveDivisible(rhs, lhs->scale)) { + // floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale + rhs = floordiv(rhs, lhs->scale); + lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); } else { - if (CanProveDivisible(rhs, lhs->scale)) { - // floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale - rhs = floordiv(rhs, lhs->scale); - lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); - } else { - // mark as unresolved. - unresolved_count_++; - return orig; - } + // mark as unresolved. + ErrorLogger(this) << "Cannot represent as IterMap: the numerator's scaling factor, " + << tvm::PrettyPrint(lhs->scale) << " and the divisor " + << tvm::PrettyPrint(rhs) + << " cannot be simplified to remove the scaling factor."; + return orig; } } @@ -1108,7 +1152,9 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, return std::move(lhs); } else { // mark as unresolved. - unresolved_count_++; + ErrorLogger(this) << "Cannot represent as IterMap: the numerator's extent, " + << tvm::PrettyPrint(lhs->extent) << " is not a multiple of the divisor, " + << tvm::PrettyPrint(rhs) << "."; return orig; } } @@ -1136,7 +1182,8 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (b->IsInstance()) { // cannot divide an iterator, mark as unresolved. - unresolved_count_++; + ErrorLogger(this) << "Cannot represent as an IterMap: the divisor in " << GetRef(op) + << " may not be an iterator"; return GetRef(op); } @@ -1145,13 +1192,16 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (Optional opt = TryFuseIters(ret)) { IterSumExpr sum = opt.value(); if (!is_zero(sum->base)) { - unresolved_count_++; + ErrorLogger(this) << "Cannot represent as an IterMap: the dividend in " + << tvm::PrettyPrint(GetRef(op)) << " has a non-zero offset."; return GetRef(op); } ICHECK_EQ(sum->args.size(), 1U); return SplitFloorDivConst(sum->args[0], b, GetRef(op)); } else { - unresolved_count_++; + ErrorLogger(this) << "Cannot represent as an IterMap: the dividend in " + << tvm::PrettyPrint(GetRef(op)) + << " cannot be represented as a single fused iterator"; return GetRef(op); } } else { @@ -1169,15 +1219,16 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, // floormod(x*c1*c2, c1) = 0 if (CanProveDivisible(lhs->scale, rhs)) { return make_zero(lhs->dtype); + } else if (CanProveDivisible(rhs, lhs->scale)) { + // floormod(x*c1, c1*c2) = (floormod(x, c2)) * c1, where c2 = rhs/scale + rhs = floordiv(rhs, lhs->scale); } else { - if (CanProveDivisible(rhs, lhs->scale)) { - // floormod(x*c1, c1*c2) = (floormod(x, c2)) * c1, where c2 = rhs/scale - rhs = floordiv(rhs, lhs->scale); - } else { - // mark as unresolved. - unresolved_count_++; - return orig; - } + // mark as unresolved. + ErrorLogger(this) + << "Cannot represent as IterMap: the left-hand side of FloorMod has a scaling factor, " + << tvm::PrettyPrint(lhs->scale) << " and the right-hand " << tvm::PrettyPrint(rhs) + << " cannot be used to simplify out the scaling factor."; + return orig; } } @@ -1189,7 +1240,10 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, return std::move(lhs); } else { // mark as unresolved. - unresolved_count_++; + ErrorLogger(this) << "Cannot represent as IterMap: the left-hand side of FloorMod has extent " + << tvm::PrettyPrint(lhs->extent) + << " which does not evenly divide the right-hand side, " + << tvm::PrettyPrint(rhs) << "."; return orig; } } @@ -1217,7 +1271,8 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (b->IsInstance()) { // cannot mod an iterator, mark as unresolved. - unresolved_count_++; + ErrorLogger(this) << "Cannot represent as an IterMap: the right-hand side of FloorMod in " + << GetRef(op) << " may not be an iterator"; return GetRef(op); } @@ -1226,12 +1281,15 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (Optional opt = TryFuseIters(ret)) { IterSumExpr sum = opt.value(); if (!is_zero(sum->base)) { - unresolved_count_++; + ErrorLogger(this) << "Cannot represent as an IterMap: the left-hand side of FloorMod in " + << tvm::PrettyPrint(GetRef(op)) << " has a non-zero offset."; return GetRef(op); } return SplitFloorModConst(sum->args[0], b, GetRef(op)); } else { - unresolved_count_++; + ErrorLogger(this) << "Cannot represent as an IterMap: the left-hand side of FloorMod in " + << tvm::PrettyPrint(GetRef(op)) + << " cannot be represented as a single fused iterator"; return GetRef(op); } } else { From cfd63f3f058764f242f786f6f4e8d0eb707d5ba2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 28 Apr 2022 13:06:19 -0500 Subject: [PATCH 02/10] [Affine] Allowed PrimExpr argument to NormalizeIterMapToExpr This allows it to be used for any expression containing an `IterMapExpr`, not just expressions whose top-level node is an `IterMapExpr`. --- include/tvm/arith/iter_affine_map.h | 6 +++--- src/arith/iter_affine_map.cc | 11 +++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index f8371b1a6176..b6b80d7c7274 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -352,11 +352,11 @@ Array> SubspaceDivide(const Array& bindings, bool require_bijective, arith::Analyzer* analyzer); /*! - * \brief Given an IterMapExpr, transform it to normal PrimExpr. - * \param expr The input IterMapExpr. + * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr. + * \param expr The input expression, which may containg IterMapExpr. * \return The corresponding normal PrimExpr. */ -PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr); +PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr); } // namespace arith } // namespace tvm diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 9a05894398a6..c672530dff53 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1299,12 +1299,13 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { } } -/*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */ +/*! * \brief Given an expression that may contain IterVarMapExpr, transform it to normal PrimExpr. + */ class IterMapToExprNormalizer : public ExprMutator { public: explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {} - PrimExpr Convert(const IterMapExpr& expr) { return VisitExpr(expr); } + PrimExpr Convert(const PrimExpr& expr) { return VisitExpr(expr); } private: /*! \brief Override VisitExpr for iter expr type processing */ @@ -1350,15 +1351,13 @@ class IterMapToExprNormalizer : public ExprMutator { Analyzer* analyzer_; }; -PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr) { +PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { arith::Analyzer analyzer; IterMapToExprNormalizer normalizer(&analyzer); return normalizer.Convert(expr); } -TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const IterMapExpr& expr) { - return NormalizeIterMapToExpr(expr); -}); +TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); Array IterMapSimplify(const Array& indices, const Map& input_iters, const PrimExpr& input_pred, bool require_bijective) { From d9310b5bf38e520ca8a929b393cc3c9664fa8d69 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 28 Apr 2022 13:54:16 -0500 Subject: [PATCH 03/10] [Affine] Implemented DetectPaddedIterMap The existing DetectIterMap tries to rewrite index expression as a linear combination of split/fused iterators, where the new iterators cover the exact same indices as the original expression. DetectPaddedIterMap relaxes this condition, allowing the new iterators to cover a superset of indices that the initial index expression covered. It uses the minimum amount of padding necessary to represent these transformations, and also a predicate that identifies any padding that has been added. This is a utility function to be used for layout transformations of buffers, in cases where the pre-transformation shape of the buffer does not evenly fit into the post-transformation shape. --- include/tvm/arith/iter_affine_map.h | 69 +++- src/arith/iter_affine_map.cc | 562 +++++++++++++++++++++------- 2 files changed, 502 insertions(+), 129 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index b6b80d7c7274..4cf6f086d1ed 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -285,6 +285,73 @@ class IterSumExpr : public IterMapExpr { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); + +/*! \brief A utility struct for return values from DetectPaddedIterMap + */ +struct PaddedIterMapResult { + // Any errors that occurred while converting the input indices. If + // the array is empty, the conversion was successful. + Array errors; + + // The detected pattern if a match exists. + Array indices; + + /* \brief Boolean expression indicating if padding was required + * + * `requires_padding` evaluates to true if the returned indices + * contain padding relative to the provided expressions, and false + * otherwise. If `input_iters` contains a variable extent, this + * expression may be in terms of those variables. + */ + PrimExpr requires_padding; + + /* \brief Boolean expression indicating if a specific value w + * + * `padding_predicate` evaluates to true for a set of indices that + * are outside the bounds of the provided index iterators, but + * inside the bounds of the returned index iterators. This + * expression is in terms of the variables provided in + * `input_iters`. + */ + PrimExpr padding_predicate; +}; + +/*! + * \brief Detect if indices can be written as + * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] + * + * Here y = some-quasi-affine-iter-map(input_iters) and c are + * symbolic constants. The y_i iterators may be padded to fit this + * representation. + * + * We also requires that y_i and y_j to be independent for i != j. + * + * For returned value rv, the following is always true: + * - rv.indices[i]->args.size() <=1: only one iterator per element. + * + * \param indices The indices to detect pattern for. + * + * \param input_iters Map from variable to iterator's range. + * + * \param predicate The predicate constraints on the input iterators + * + * \param require_bijective A boolean flag that indicates whether the + * mapping should be bijective. If true, no padding may be + * introduced. + * + * \param analyzer Analyzer used to get context information. + * + * \param simplify_trivial_iterators If true, iterators with extent of + * 1 will be replaced with a constant value. + * + * \return An instance of PaddedIterMapResult. + */ +PaddedIterMapResult DetectPaddedIterMap(const Array& indices, + const Map& input_iters, + const PrimExpr& predicate, bool require_bijective, + arith::Analyzer* analyzer, + bool simplify_trivial_iterators = true); + /*! * \brief Use IterVarMap detector to rewrite and simplify the indices * @@ -353,7 +420,7 @@ Array> SubspaceDivide(const Array& bindings, /*! * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr. - * \param expr The input expression, which may containg IterMapExpr. + * \param expr The input expression, which may contain IterMapExpr. * \return The corresponding normal PrimExpr. */ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr); diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index c672530dff53..829a32cddba7 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include "../support/utils.h" #include "const_fold.h" @@ -174,8 +175,11 @@ class IterMapRewriter : public ExprMutator { using Parent = ExprMutator; explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, - bool simplify_trivial_iterators) - : analyzer_(analyzer) { + bool simplify_trivial_iterators, Array* errors) + : analyzer_(analyzer), + errors_(*errors), + requires_padding_(const_false()), + padding_predicate_(const_false()) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; @@ -195,20 +199,19 @@ class IterMapRewriter : public ExprMutator { } } - size_t unresolved_count() const { return errors_.size(); } - - void print_errors() const { - for (const auto& err : errors_) { - std::cout << "Error: " << err << std::endl; - } - } - - std::vector errors() const { return errors_; } + PrimExpr padding_predicate() const { return padding_predicate_; } + PrimExpr requires_padding() const { return requires_padding_; } IterSumExpr Rewrite(const PrimExpr& expr) { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } + void UpdatePadding(const PrimExpr& expr) { + update_iterator_padding_ = true; + DirectMutate(expr); + update_iterator_padding_ = false; + } + IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const Optional& predicate_induced_min, const Optional& predicate_induced_max) { @@ -242,6 +245,7 @@ class IterMapRewriter : public ExprMutator { // All the splits that refers to the iter_mark covers its extent. // The splits do not overlap with each other. collector.Collect(bindings); + for (const IterMark& mark : collector.visited_) { if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) { return false; @@ -318,6 +322,22 @@ class IterMapRewriter : public ExprMutator { PrimExpr VisitExpr_(const FloorModNode* op) final; private: + // Preprocessing common to both FloorDiv and FloorMod + IterSumExpr PreprocessDividend(IterMapExpr dividend); + + // Create an iterator that represents the expression (split+base), with + // padding such that the iterator's extents are evenly divisible by + // `divisor`. + // + // If iterators can have padding added through UpdatePadding, pad a + // dividend out to be evenly divisible. Otherwise, validate that the + // padding previously defined for the split using UpdatePadding can be + // used. If no such previous padding exists, return an empty + // IterMark. + IterSplitExpr PadDividendToDivisor(IterSplitExpr split, PrimExpr base, PrimExpr divisor); + + friend struct ErrorLogger; + /* \brief Utility class for logging errors. * * It is not an error for IterMapRewriter to receive an expression that @@ -328,7 +348,6 @@ class IterMapRewriter : public ExprMutator { * report an unrepresentable TIR graph, which may be used in error * messages at the calling scope. */ - friend struct ErrorLogger; class ErrorLogger { public: explicit ErrorLogger(IterMapRewriter* rewriter) : rewriter(rewriter) {} @@ -345,6 +364,19 @@ class IterMapRewriter : public ExprMutator { std::ostringstream os; }; + struct IterPaddingInfo { + IterPaddingInfo() : var_extent("extent") {} + // Used and collected during first pass + IterSplitExpr unpadded; + Var var_extent; + std::vector divisors; + + // Defined on first encounter in second pass + IterSplitExpr padded; + PrimExpr left_pad; + PrimExpr right_pad; + }; + // temp hash for de-duplication purposes. struct IterSumHash { size_t operator()(const IterSumExpr& value) const { @@ -382,11 +414,57 @@ class IterMapRewriter : public ExprMutator { // Internal analyzer Analyzer* analyzer_; // Error messages for each unresolved expression. - std::vector errors_; + Array& errors_; // The var map std::unordered_map var_map_; // input iter marks std::vector input_marks_; + + // Map from an IterSumExpr containing padding to the IterMark + // representing the padded version. + std::unordered_map padded_iter_map_; + + /* If allow_padding_ is true, allow the extents of the IterMap to be + * padded beyond the original iterators. + * + * For example, if allow_padding_ is true, the expressions i//4 and + * i%4, where i is on the range [0,18), would be represented as + * IterSplit(i, lower_factor=4, extent=5) and IterSplit(i, extent=4). + * This representation would be forbidden if allow_padding_ is false, + * because lower_factor=4 does not evenly divide the original extent of + * 18. + */ + bool update_iterator_padding_{false}; + + /* A boolean expression that is true if any padding has been introduced + * by the transformation, and false otherwise. + * + * Example: [i//4, i%4], i in range [0,16) + * requires_padding_ will be false + * + * Example: [i//4, i%4], i in range [0,18) + * requires_padding_ will be true + * + * Example: [i//4, i%4], i in range [0,N) + * requires_padding_ will be the expression N%4==0 + */ + PrimExpr requires_padding_; + + /* A boolean expression that is true for any padding that has been + * introduced, and false otherwise. If allow_padding_ is false, + * padding_predicate_ will always be false. + * + * Example: [i//4, i%4], i in range [0,16) + * padding_predicate_ will be false + * + * Example: [i//4, i%4], i in range [0,18) + * padding_predicate_ will be `(i//4 == 3) && (i%4 >= 2)` + * + * Example: [i//4, i%4], i in range [0,N) + * padding_predicate_ will be `(N%4!=0) && (i//4 == (N+3)//4-1) && (i%4 >= N%4)` + */ + PrimExpr padding_predicate_; + // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly // an extra offset) // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) @@ -465,8 +543,9 @@ class IterMapRewriter : public ExprMutator { size_t j = 0; for (; j < splits.size(); ++j) { if (used[j]) continue; - if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) + if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) { break; + } } if (j == splits.size()) { // we do not allow incomplete split if the bindings should be bijective @@ -483,6 +562,7 @@ class IterMapRewriter : public ExprMutator { return Array(); } } + used[j] = true; iters.push_back(splits[j]); expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; @@ -493,9 +573,23 @@ class IterMapRewriter : public ExprMutator { // Case 2. bijective is not required. // We check the extent we calculate is a factor of the extent of the mark // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. - if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) || - (!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) { - return Array(); + if (require_bijective) { + if (!analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) { + return Array(); + } + } else { + // This still requires that the splits can be part of a + // bijective transformation, even if they are not sufficient to + // form a bijective transformation. + // + // TODO: Why is this condition used, and what is it intended to allow? + // + // if (!CanProveDivisible(mark->extent, expected_lower_factor)) { + // std::cout << "Expected resulting lower factor " << expected_lower_factor + // << " to be divisible by the extent of the parent mark " << mark->extent + // << std::endl; + // return Array(); + // } } return Array(iters.rbegin(), iters.rend()); } @@ -718,15 +812,10 @@ class IterMapRewriter : public ExprMutator { } } - bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) { - const auto* clhs = lhs.as(); - const auto* crhs = rhs.as(); - if (clhs && crhs) return clhs->value % crhs->value == 0; - return analyzer_->CanProveEqual(lhs, rhs) || analyzer_->CanProve(floormod(lhs, rhs) == 0); - } + bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs); - PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig); - PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig); + PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs); + PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs); static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) { tir::ExprDeepEqual equal; @@ -931,17 +1020,37 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer, bool simplify_trivial_iterators) { + auto padded_result = DetectPaddedIterMap(indices, input_iters, predicate, require_bijective, + analyzer, simplify_trivial_iterators); + if (padded_result.errors.size()) { + return Array(); + } + if (!analyzer->CanProve(!padded_result.requires_padding)) { + return Array(); + } + return padded_result.indices; +} + +PaddedIterMapResult DetectPaddedIterMap(const Array& indices, + const Map& input_iters, + const PrimExpr& predicate, bool require_bijective, + arith::Analyzer* analyzer, + bool simplify_trivial_iterators) { + PaddedIterMapResult result; + // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. if (!IterRangeSanityCheck(input_iters)) { - return Array(); + result.errors.push_back("Invalid iterators. Iterators may not be expressions of each other."); + return result; } Map constrained_input_iters = input_iters; std::vector constraints; if (!is_one(predicate) && !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { - return Array(); + result.errors.push_back("Could not parse predicate as constraints on the input iterators."); + return result; } // We have to make sure when we visit an iterator, all the constraints related with its successors // in the iter var graph has been visited, where the expression of this iterator will contain the @@ -954,33 +1063,51 @@ Array DetectIterMap(const Array& indices, const Map(); + if (result.errors.size()) { + return result; } } if (!rewriter.CheckConstraints()) { - return Array(); + result.errors.push_back("Invalid constraints."); + return result; } - // Step0.1: rewrite indices - Array results; + // Step0.1: Check each index to determine required padding + bool allow_padding = !require_bijective; + if (allow_padding) { + for (PrimExpr value : indices) { + rewriter.UpdatePadding(value); + } + } + + // Step0.2: rewrite indices for (PrimExpr value : indices) { - results.push_back(rewriter.Rewrite(value)); - if (rewriter.unresolved_count() != 0) { - return Array(); + result.indices.push_back(rewriter.Rewrite(value)); + if (result.errors.size()) { + return result; } } + + result.requires_padding = rewriter.requires_padding(); + result.padding_predicate = rewriter.padding_predicate(); + // Step1: IterIndependenceChecker checks if the iterator are independent. - if (!rewriter.CheckMapping(results, require_bijective)) { - return Array(); + if (!rewriter.CheckMapping(result.indices, require_bijective)) { + if (require_bijective) { + result.errors.push_back("Index mapping does not form a bijective transform."); + } else { + result.errors.push_back("Mapped indices are not independent."); + } + return result; } - return results; + return result; } TVM_REGISTER_GLOBAL("arith.DetectIterMap") @@ -1113,50 +1240,230 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { } } -PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, - const PrimExpr& orig) { - // floordiv(x*scale, rhs) - if (is_one(rhs)) return std::move(lhs); +IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend) { + if (dividend->IsInstance()) { + auto split = Downcast(dividend); + return IterSumExpr({split}, make_zero(split.dtype())); + } else if (dividend->IsInstance()) { + auto opt_fused = TryFuseIters(Downcast(dividend)); + if (!opt_fused) { + ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(dividend) + << ", can't be written as a single fused IterSum"; + return IterSumExpr(); + } + + IterSumExpr fused = opt_fused.value(); + + ICHECK_EQ(fused->args.size(), 1U); + return fused; + } else { + LOG(FATAL) << "Unsupported subclass of IterMarkExpr"; + return IterSumExpr(); + } +} + +IterSplitExpr IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExpr base, + PrimExpr divisor) { + // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor + // If FloorMod: (((source//lower_factor) % extent) + base) % divisor + + IterSumExpr lookup_key({split}, 0); + + auto modified_divisor = [&]() { + if (update_iterator_padding_) { + return divisor; + } + + auto it = padded_iter_map_.find(lookup_key); + if (it == padded_iter_map_.end()) { + return divisor; + } + + const std::vector& divisors = it->second.divisors; + PrimExpr largest_divisor = divisor; + for (const auto& other : divisors) { + if (CanProveDivisible(other, largest_divisor)) { + // New one is bigger, use it + largest_divisor = other; + } else if (CanProveDivisible(largest_divisor, other)) { + // Current is bigger, keep it + } else { + ErrorLogger(this) << "Iterator appears in multiple terms with incompatible divisors " + << tvm::PrettyPrint(largest_divisor) << " and " + << tvm::PrettyPrint(other); + } + } + return largest_divisor; + }(); + + divisor = modified_divisor; + + // First, adding any padding that is on the lower side of a + // FloorDiv/FloorMod, such that floormod(iter-left_pad,divisor) == 0 + // when iter==0. + + PrimExpr left_pad; + + if (is_zero(base)) { + // Padding on the left is unnecessary if base is known to be zero. + left_pad = make_zero(base->dtype); + } else { + left_pad = floormod(base, divisor); + } + + // Next, adding any padding that is on the upper side of a + // FloorDiv/FloorMod, such that floormod(left_pad + iter + right_pad, divisor) == 0 + // when iter==extent. + + PrimExpr right_edge = left_pad + split->extent; + PrimExpr right_pad; + + if (CanProveDivisible(right_edge, divisor)) { + // Padding on the right is unnecessary if the extent is a multiple of + // the divisor. + right_pad = 0; + } else { + right_pad = floormod(-right_edge, divisor); + } + + if (is_zero(left_pad) && is_zero(right_pad)) { + return split; + } + + if (update_iterator_padding_) { + // In the first pass, the primary goal is to collect all the divisors + // that may be used for padding. These will impact the divisor used + // to determine padding in the second pass. + IterPaddingInfo& info = padded_iter_map_[lookup_key]; + + info.unpadded = split; + info.divisors.push_back(divisor); + + // If an iterator contains padding, it may require a subsequent + // iterator to also require padding. Therefore, introduce a variable + // for the extent. + + // PrimExpr padded_extent = left_pad + split->extent + right_pad; + // PrimExpr padded_extent = Var("N", divisor.dtype())*divisor; + PrimExpr padded_extent = info.var_extent; + + IterSumExpr as_sum({split}, base); + IterMark mark(as_sum, padded_extent); + IterSplitExpr new_split(mark); + + return new_split; + } + + // Any padding that is required during parsing should have been found + // during the first pass that determines the GCD. + auto it = padded_iter_map_.find(lookup_key); + if (it == padded_iter_map_.end()) { + ErrorLogger(this) << "Dividend has extent " << tvm::PrettyPrint(split->extent) << " and offset " + << tvm::PrettyPrint(base) << ", which requires padding for divisor " + << tvm::PrettyPrint(divisor) << "."; + return IterSplitExpr(); + } + IterPaddingInfo& info = it->second; + + if (info.padded.defined()) { + // A previous visit already applied padding to this iterator. + // (e.g. Visiting `(i+1)//4`, then visiting `(i+1)%4`). + ICHECK(analyzer_->CanProveEqual(info.left_pad, left_pad)); + ICHECK(analyzer_->CanProveEqual(info.right_pad, right_pad)); + + return info.padded; + } + + // This is the first encounter with the iterator during the second pass. + + IterMark mark(IterSumExpr({split}, base), left_pad + split->extent + right_pad); + info.padded = IterSplitExpr(mark); + info.left_pad = left_pad; + info.right_pad = right_pad; + + auto left_padding_introduced = (left_pad != 0); + // Equivalent to (0 <= split < left_pad), but easier to simplify in + // terms of the transformed variables. + auto left_padding_predicate = + left_padding_introduced && + (floordiv(info.padded, divisor) == 0 && floormod(info.padded, divisor) < left_pad); + + PrimExpr nparts = ceildiv(right_edge, divisor); + + auto right_padding_introduced = (right_pad != 0); + + // Equivalent to (right_edge <= split < right_edge+right_pad), but + // easier to simplify in terms of the transformed variables. + auto right_padding_predicate = + right_padding_introduced && (floordiv(info.padded, divisor) == nparts - 1 && + floormod(info.padded, divisor) >= floormod(right_edge, divisor)); + + requires_padding_ = requires_padding_ || (left_padding_introduced || right_padding_introduced); + padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); + + return info.padded; +} + +PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { + // (lhs + base) // rhs + + if (is_one(rhs)) { + if (is_zero(base)) { + // floordiv(x, 1) = x + return std::move(lhs); + } else { + // floordiv(x+y, 1) = x+y + return IterSumExpr({lhs}, base); + } + } + if (!is_one(lhs->scale)) { - if (CanProveDivisible(lhs->scale, rhs)) { + if (CanProveDivisible(lhs->scale, rhs) && is_zero(base)) { // floordiv(x*c1*c2, c2) = x*c1, c1=scale/rhs lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs); return std::move(lhs); - } else if (CanProveDivisible(rhs, lhs->scale)) { + } else if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) { + // floordiv(x*c1*c2 + y*c2, c2) = x*c1 + y, c1=scale/rhs + lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs); + return IterSumExpr({lhs}, floordiv(base, rhs)); + } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { // floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale rhs = floordiv(rhs, lhs->scale); lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); + } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) { + // floordiv(x*c1 + y*c1, c1*c2) = floordiv(x+y, c2), c2=rhs/scale + base = floordiv(base, lhs->scale); + rhs = floordiv(rhs, lhs->scale); + lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); } else { // mark as unresolved. ErrorLogger(this) << "Cannot represent as IterMap: the numerator's scaling factor, " << tvm::PrettyPrint(lhs->scale) << " and the divisor " << tvm::PrettyPrint(rhs) << " cannot be simplified to remove the scaling factor."; - return orig; + return PrimExpr(); } } // We handle scale!=1 in above code, hence we only consider floordiv(x, rhs) below - // where x=floormod(floordiv(iter, lower_factor), extent) - if (CanProveDivisible(lhs->extent, rhs)) { - // floordiv(floormod(floordiv(iter, lower_factor), c1c2), c1) - // = floordiv(floormod(y, c1c2), c1), where y=floordiv(iter, lower_factor) - // = floordiv(floormod(sc1c2+tc1+u, c1c2), c1), where y=sc1c2+tc1+u, tlower_factor *= rhs; - ptr_lhs->extent = analyzer_->Simplify(floordiv(ptr_lhs->extent, rhs)); - return std::move(lhs); - } else { - // mark as unresolved. - ErrorLogger(this) << "Cannot represent as IterMap: the numerator's extent, " - << tvm::PrettyPrint(lhs->extent) << " is not a multiple of the divisor, " - << tvm::PrettyPrint(rhs) << "."; - return orig; - } + // where x=floormod(floordiv(iter, lower_factor), extent) + base + + IterSplitExpr padded = PadDividendToDivisor(lhs, base, rhs); + if (!padded.defined()) { + return PrimExpr(); + } + + // floordiv(floormod(floordiv(iter, lower_factor), c1c2), c1) + // = floordiv(floormod(y, c1c2), c1), where y=floordiv(iter, lower_factor) + // = floordiv(floormod(sc1c2+tc1+u, c1c2), c1), where y=sc1c2+tc1+u, tsource, + /* lower_factor = */ padded->lower_factor * rhs, + /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), + /* scale = */ padded->scale); } PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { @@ -1187,65 +1494,60 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { return GetRef(op); } - if (a->IsInstance()) { - IterSumExpr ret = Downcast(a); - if (Optional opt = TryFuseIters(ret)) { - IterSumExpr sum = opt.value(); - if (!is_zero(sum->base)) { - ErrorLogger(this) << "Cannot represent as an IterMap: the dividend in " - << tvm::PrettyPrint(GetRef(op)) << " has a non-zero offset."; - return GetRef(op); - } - ICHECK_EQ(sum->args.size(), 1U); - return SplitFloorDivConst(sum->args[0], b, GetRef(op)); - } else { - ErrorLogger(this) << "Cannot represent as an IterMap: the dividend in " - << tvm::PrettyPrint(GetRef(op)) - << " cannot be represented as a single fused iterator"; - return GetRef(op); - } - } else { - ICHECK(a->IsInstance()); - IterSplitExpr ret = Downcast(std::move(a)); - return SplitFloorDivConst(ret, b, GetRef(op)); + IterSumExpr preprocessed = PreprocessDividend(Downcast(a)); + if (!preprocessed.defined()) { + return GetRef(op); + } + PrimExpr remainder = SplitFloorDivConst(preprocessed->args[0], preprocessed->base, b); + if (!remainder.defined()) { + return GetRef(op); } + return remainder; } -PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, - const PrimExpr& orig) { - // floormod(x*scale, rhs) - if (is_one(rhs)) return make_zero(lhs->dtype); +PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { + // (lhs + base) % rhs + + if (is_one(rhs)) { + // floormod(x, 1) = 0 + return make_zero(lhs->dtype); + } + if (!is_one(lhs->scale)) { - // floormod(x*c1*c2, c1) = 0 - if (CanProveDivisible(lhs->scale, rhs)) { + if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) { + // floormod(x*c1*c2, c1) = 0 return make_zero(lhs->dtype); - } else if (CanProveDivisible(rhs, lhs->scale)) { + } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { // floormod(x*c1, c1*c2) = (floormod(x, c2)) * c1, where c2 = rhs/scale rhs = floordiv(rhs, lhs->scale); + } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) { + // floormod(x*c1 + y*c1, c1*c2) = (floormod(x+y, c2)) * c1, where c2 = rhs/scale + rhs = floordiv(rhs, lhs->scale); + base = floordiv(base, lhs->scale); } else { // mark as unresolved. ErrorLogger(this) << "Cannot represent as IterMap: the left-hand side of FloorMod has a scaling factor, " << tvm::PrettyPrint(lhs->scale) << " and the right-hand " << tvm::PrettyPrint(rhs) << " cannot be used to simplify out the scaling factor."; - return orig; + return PrimExpr(); } } - // floormod(x, rhs) where x=floormod(floordiv(iter, lower_factor), extent) - if (CanProveDivisible(lhs->extent, rhs)) { - // floormod(floormod(floordiv(iter, lower_factor), c1c2), c1) - // = floormod(floordiv(iter, lower_factor), c1), where c1=rhs - lhs.CopyOnWrite()->extent = rhs; - return std::move(lhs); - } else { - // mark as unresolved. - ErrorLogger(this) << "Cannot represent as IterMap: the left-hand side of FloorMod has extent " - << tvm::PrettyPrint(lhs->extent) - << " which does not evenly divide the right-hand side, " - << tvm::PrettyPrint(rhs) << "."; - return orig; + // We handle scale!=1 in above code, hence we only consider floormod(x, rhs) below + // where x=floormod(floordiv(iter, lower_factor), extent) + base + + IterSplitExpr padded = PadDividendToDivisor(lhs, base, rhs); + if (!padded.defined()) { + return PrimExpr(); } + + // floormod(floormod(floordiv(iter, lower_factor), c1c2), c1) + // = floormod(floordiv(iter, lower_factor), c1), where c1=rhs + return IterSplitExpr(padded->source, + /* lower_factor = */ padded->lower_factor, + /* extent = */ rhs, + /* scale = */ padded->scale); } PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { @@ -1276,27 +1578,16 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { return GetRef(op); } - if (a->IsInstance()) { - IterSumExpr ret = Downcast(a); - if (Optional opt = TryFuseIters(ret)) { - IterSumExpr sum = opt.value(); - if (!is_zero(sum->base)) { - ErrorLogger(this) << "Cannot represent as an IterMap: the left-hand side of FloorMod in " - << tvm::PrettyPrint(GetRef(op)) << " has a non-zero offset."; - return GetRef(op); - } - return SplitFloorModConst(sum->args[0], b, GetRef(op)); - } else { - ErrorLogger(this) << "Cannot represent as an IterMap: the left-hand side of FloorMod in " - << tvm::PrettyPrint(GetRef(op)) - << " cannot be represented as a single fused iterator"; - return GetRef(op); - } - } else { - ICHECK(a->IsInstance()); - IterSplitExpr ret = Downcast(std::move(a)); - return SplitFloorModConst(ret, b, GetRef(op)); + IterSumExpr preprocessed = PreprocessDividend(Downcast(a)); + if (!preprocessed.defined()) { + return GetRef(op); } + + PrimExpr remainder = SplitFloorModConst(preprocessed->args[0], preprocessed->base, b); + if (!remainder.defined()) { + return GetRef(op); + } + return remainder; } /*! * \brief Given an expression that may contain IterVarMapExpr, transform it to normal PrimExpr. @@ -1351,6 +1642,21 @@ class IterMapToExprNormalizer : public ExprMutator { Analyzer* analyzer_; }; +bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) { + const auto* clhs = lhs.as(); + const auto* crhs = rhs.as(); + if (clhs && crhs) { + return clhs->value % crhs->value == 0; + } + + IterMapToExprNormalizer normalizer(analyzer_); + PrimExpr dividend = normalizer.Convert(lhs); + PrimExpr divisor = normalizer.Convert(rhs); + + return analyzer_->CanProveEqual(dividend, divisor) || + analyzer_->CanProve(floormod(dividend, divisor) == 0); +} + PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { arith::Analyzer analyzer; IterMapToExprNormalizer normalizer(&analyzer); From a48abbb6ed8fd61f178075043f917b42ce4e1729 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 22 Apr 2022 14:06:26 -0500 Subject: [PATCH 04/10] [IndexMap] Implemented IndexMap::NonSurjectiveInverse Allow non-surjective transformations, with DetectIterMap used to determine the minimum padding to insert. Returns the inverse function, along with a predicate that identifies padding indices. The predicate is in terms of the transformed variables. --- include/tvm/tir/index_map.h | 20 ++++++++++-- src/tir/ir/index_map.cc | 64 +++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 3 deletions(-) diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 195bf7e02ce3..b6faa67ab53a 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -31,6 +31,8 @@ #include #include +#include + namespace tvm { namespace tir { @@ -141,12 +143,24 @@ class IndexMap : public ObjectRef { * * TODO(Lunderberg): Look into allowing non-bijective * transformations. If injective, the inverse mapping could still - * be generated with some predicate. If non-injective, could - * simplify the implementation of other optimizations (e.g. double - * buffering as a map `lambda *indices: [buffer_loop%2, *indices]`). + * be generated with some predicate (see NonSurjectiveInverse). If + * non-injective, could simplify the implementation of other + * optimizations (e.g. double buffering as a map `lambda *indices: + * [buffer_loop%2, *indices]`). */ IndexMap Inverse(Array initial_ranges) const; + /*! \brief Generate the inverse mapping. + * + * Determine the inverse, where the output range may contain + * addresses that do not correspond to an address in the input + * range. + * + * \return The inverted index map, along with the predicate for + * which the inverse maps to a valid range. + */ + std::pair NonSurjectiveInverse(Array initial_ranges) const; + TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode); }; diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 93f308b42d74..64314728f823 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -50,6 +50,70 @@ IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc(A return IndexMap(initial_indices, func(initial_indices)); } +std::pair IndexMap::NonSurjectiveInverse(Array initial_ranges) const { + // Dummy variables to represent the inverse's inputs. + Array output_vars; + for (size_t i = 0; i < (*this)->final_indices.size(); i++) { + PrimExpr index = (*this)->final_indices[i]; + // TODO(Lunderberg): Better names for these variables. A variable + // that is passed through unmodified (`index` is an element of + // `initial_indices`) should use that input index's name. A pair + // of output indices variables split from a single input index + // should be named (X.outer,X.inner). + std::stringstream ss; + ss << "axis" << i; + Var var_index(ss.str(), index.dtype()); + output_vars.push_back(var_index); + } + + // Dummy ranges for the extent of each input. + Map input_iters; + ICHECK_EQ((*this)->initial_indices.size(), initial_ranges.size()); + for (size_t i = 0; i < initial_ranges.size(); i++) { + input_iters.Set((*this)->initial_indices[i], initial_ranges[i]); + } + + // Unpack the output indices into linear combinations of the initial + // indices. + arith::Analyzer analyzer; + auto padded_iter_map = + DetectPaddedIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, + /* require_bijective = */ false, &analyzer, + /* simplify_trivial_iterators = */ false); + CHECK(padded_iter_map.errors.empty()) << "Could not parse mapping as sum of iterators. " + << "Error: " << padded_iter_map.errors[0]; + + // Determine expressions for the input variables, in terms of the + // output variables. + Map inverse_exprs_map = InverseAffineIterMap( + padded_iter_map.indices, Array(output_vars.begin(), output_vars.end())); + + // Unpack the map to an array, maintaining the same parameter order. + Array inverse_exprs; + for (const auto& index : (*this)->initial_indices) { + inverse_exprs.push_back(inverse_exprs_map.at(index)); + } + + PrimExpr padding_predicate = padded_iter_map.padding_predicate; + padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate); + padding_predicate = Substitute(padding_predicate, inverse_exprs_map); + + { + auto output_ranges = (*this)->MapRanges(initial_ranges); + ICHECK_EQ(output_ranges.size(), output_vars.size()); + + arith::Analyzer analyzer; + for (size_t i = 0; i < output_vars.size(); ++i) { + analyzer.Bind(output_vars[i], output_ranges[i]); + } + + // Additional simplification steps required to unwrap nested floordiv/floormod + padding_predicate = analyzer.Simplify(padding_predicate, 10); + } + + return {IndexMap(output_vars, inverse_exprs), padding_predicate}; +} + IndexMap IndexMap::Inverse(Array initial_ranges) const { // Dummy variables to represent the inverse's inputs. Array output_vars; From 339550445a7fb0e73e29a57c5214c4f9f068f6cf Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 6 May 2022 11:15:03 -0500 Subject: [PATCH 05/10] [IndexMap] Exposed methods to python - `IndexMap::Inverse` exposed as `IndexMap.inverse` - `IndexMap::MapShape` exposed as `IndexMap.map_shape` - `IndexMap::NonSurjectiveInverse` exposed as `IndexMap.non_surjective_inverse` --- python/tvm/tir/function.py | 85 ++++++++++++++++++++++++++++++++++++-- src/tir/ir/index_map.cc | 8 ++++ 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 643bbca8eebd..a1a790d48a96 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -16,13 +16,14 @@ # under the License. """Function data types.""" -from typing import Callable, List, Mapping, Optional, Union +from typing import Callable, List, Mapping, Optional, Union, Tuple import inspect +import tvm import tvm._ffi import tvm.runtime from tvm.runtime import Object -from tvm.ir import BaseFunc +from tvm.ir import BaseFunc, Range from .buffer import Buffer from .expr import Var, PrimExpr from . import _ffi_api @@ -301,7 +302,7 @@ def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]: Parameters ---------- - indices : List[PriExpr] + indices : List[PrimExpr] The indices to be mapped Returns @@ -310,3 +311,81 @@ def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]: The mapped indices """ return _ffi_api.IndexMapMapIndices(self, indices) + + def map_shape(self, shape: List[PrimExpr]) -> List[PrimExpr]: + """Apply the index map to a buffer shape + + Parameters + ---------- + shape : List[PrimExpr] + The buffer shape to be mapped + + Returns + ------- + result : List[PrimExpr] + The mapped shape + """ + return _ffi_api.IndexMapMapShape(self, shape) + + def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap": + """Return the inverse of the map + + Throws an error if the function is not bijective. + + Paramters + --------- + shape: List[Union[Range,PrimExpr]] + + The region over which the inverse should be determined. + Used for validating that the mapping is bijective over + this range. + + Returns + ------- + inverse : IndexMap + + The inverse + """ + + shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape] + return _ffi_api.IndexMapInverse(self, shape) + + def non_surjective_inverse( + self, shape: List[Union[Range, PrimExpr]] + ) -> Tuple["IndexMap", PrimExpr]: + """Return the inverse of the map + + Can be applied to transformations that introduce padding. + + Examples + -------- + + Before unroll, in TensorIR, the IR is: + + .. code-block:: python + + index_map = IndexMap.from_func(lambda i: [i//4, i%4]) + inverse_map, predicate = index_map.non_surjective_inverse([14]) + inverse_map + + .. code-block:: python + + index_map = IndexMap.from_func(lambda i: [i//4, i%4]) + + Paramters + --------- + shape: List[Union[Range,PrimExpr]] + + The region over which the inverse should be determined. + Used for determining the predicate. + + Returns + ------- + result : Tuple[IndexMap, PrimExpr] + + The inverse, and a predicate for which the inverse maps to + a valid index in the input range. + """ + + shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape] + return _ffi_api.IndexMapNonSurjectiveInverse(self, shape) diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 64314728f823..4c0a7d3508c1 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -266,6 +266,14 @@ TVM_REGISTER_GLOBAL("tir.IndexMap") }); TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices").set_body_method(&IndexMapNode::MapIndices); +TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_method(&IndexMapNode::MapShape); +TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse); + +TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse") + .set_body_typed([](IndexMap forward, Array initial_ranges) { + auto result = forward.NonSurjectiveInverse(initial_ranges); + return Array{result.first, result.second}; + }); } // namespace tir } // namespace tvm From 057a9e1fdf96de31aaa1d1c8e38b35c27574e887 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 22 Apr 2022 13:57:14 -0500 Subject: [PATCH 06/10] [IndexMap] Extracted _assert_equal_index_map into class method In preparation for adding additional tests for the IndexMap class, which will require this functionality. --- python/tvm/tir/function.py | 30 +++++++++++++++++++ .../unittest/test_tir_schedule_analysis.py | 12 ++------ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index a1a790d48a96..7029d6f9f809 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -297,6 +297,36 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): final_indices = mapping_function(*args) return IndexMap(args, final_indices) + def is_equivalent_to(self, other_map: "IndexMap") -> bool: + """Return if the index maps are equivalent. + + Parameters + ---------- + other_map: IndexMap + + The IndexMap to which the comparison should be made. + + Returns + ------- + is_equivalent: bool + + True if the two mappings represent the same + transformation, otherwise False + """ + if len(self.initial_indices) != len(other_map.initial_indices): + return False + if len(self.final_indices) != len(other_map.final_indices): + return False + + analyzer = tvm.arith.Analyzer() + + mapped_other_final_indices = other_map.map_indices(self.initial_indices) + for self_index, other_index in zip(self.final_indices, mapped_other_final_indices): + if not analyzer.can_prove_equal(self_index, other_index): + return False + + return True + def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]: """Apply the index map to a set of indices diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 10371d3ccaf1..19be0b8699ac 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -48,14 +48,6 @@ def _make_loops(loop_vars: List[Var], extents: List[int]) -> List[For]: ] -def _assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: - iters_1 = map1.map_indices(map2.initial_indices) - iters_2 = map2.final_indices - assert len(iters_1) == len(iters_2) - for iter1, iter2 in zip(iters_1, iters_2): - assert expr_deep_equal(iter1, iter2) - - def test_suggest_index_map_simple(): i, j = _make_vars("i", "j") index_map = suggest_index_map( @@ -78,7 +70,7 @@ def test_suggest_index_map_simple(): floormod(y, 16), ], ) - _assert_equal_index_map(index_map, expected_index_map) + assert index_map.is_equivalent_to(expected_index_map) def test_suggest_index_map_bijective(): @@ -98,7 +90,7 @@ def test_suggest_index_map_bijective(): floordiv(x, 2), ], ) - _assert_equal_index_map(index_map, expected_index_map) + assert index_map.is_equivalent_to(expected_index_map) @tvm.script.ir_module From 882b02e13c1956ec5aaca8d2f3145d76e2a743ea Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 22 Apr 2022 14:05:23 -0500 Subject: [PATCH 07/10] [IndexMap] Added unit tests for new behavior --- tests/python/unittest/test_index_map.py | 189 ++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 tests/python/unittest/test_index_map.py diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_index_map.py new file mode 100644 index 000000000000..dff1bebbf2e6 --- /dev/null +++ b/tests/python/unittest/test_index_map.py @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +from tvm.tir import IndexMap +from tvm.ir import assert_structural_equal + + +def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: + + iters_1 = map1.map_indices(map2.initial_indices) + iters_2 = map2.final_indices + assert len(iters_1) == len(iters_2) + + analyzer = tvm.arith.Analyzer() + for iter1, iter2 in zip(iters_1, iters_2): + assert analyzer.can_prove_equal(iter1, iter2) + + +def test_index_mapping(): + index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + + assert_structural_equal(index_map.map_indices([0]), [0, 0]) + assert_structural_equal(index_map.map_indices([3]), [0, 3]) + assert_structural_equal(index_map.map_indices([4]), [1, 0]) + assert_structural_equal(index_map.map_indices([42]), [10, 2]) + + +def test_shape_mapping(): + index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + + assert_structural_equal(index_map.map_shape([4]), [1, 4]) + assert_structural_equal(index_map.map_shape([16]), [4, 4]) + + assert_structural_equal(index_map.map_shape([14]), [4, 4]) + + +def test_inverse(): + index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + expected_inverse = IndexMap.from_func(lambda i, j: [4 * i + j]) + + assert index_map.inverse([16]).is_equivalent_to(expected_inverse) + + +def test_nonbijective_inverse_gives_error(): + index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) + + with pytest.raises(tvm.TVMError): + index_map.inverse([14]) + + +dynamic_N = tvm.tir.Var("N", "int32") +padding_test_case = tvm.testing.parameter( + by_dict={ + "no_padding": dict( + forward=lambda i: [i // 4, i % 4], + inverse=lambda i, j: [4 * i + j], + pre_shape=[16], + post_shape=[4, 4], + padding=lambda i, j: tvm.runtime.convert(False), + ), + "right_padding": dict( + forward=lambda i: [i // 4, i % 4], + inverse=lambda i, j: [4 * i + j], + pre_shape=[15], + post_shape=[4, 4], + padding=lambda i, j: tvm.tir.And(i == 3, j >= 3), + ), + "left_padding": dict( + forward=lambda i: [(i + 1) // 4, (i + 1) % 4], + inverse=lambda i, j: [4 * i + j - 1], + pre_shape=[15], + post_shape=[4, 4], + padding=lambda i, j: tvm.tir.And(i == 0, j < 1), + ), + "left_and_right_padding": dict( + forward=lambda i: [(i + 1) // 4, (i + 1) % 4], + inverse=lambda i, j: [4 * i + j - 1], + pre_shape=[14], + post_shape=[4, 4], + padding=lambda i, j: tvm.tir.Or( + tvm.tir.And(i == 0, j < 1), + tvm.tir.And(i == 3, j >= 3), + ), + ), + "dynamic_size": dict( + forward=lambda i: [i // 4, i % 4], + inverse=lambda i, j: [4 * i + j], + pre_shape=[dynamic_N], + post_shape=[(dynamic_N - 1) // 4 + 1, 4], + padding=lambda i, j: tvm.tir.And( + dynamic_N % (-4) != 0, + tvm.tir.And(i == (dynamic_N + 3) // 4 - 1, j >= dynamic_N % 4), + ), + ), + "2d_padding": dict( + forward=lambda i, j: [(i + 1) // 4, (j + 5) // 8, (i + 1) % 4, (j + 5) % 8], + inverse=lambda i_outer, j_outer, i_inner, j_inner: [ + 4 * i_outer + i_inner - 1, + 8 * j_outer + j_inner - 5, + ], + pre_shape=[14, 31], + post_shape=[ + 4, # ceildiv(left_pad + i.extent, 4) = ceildiv(1 + 14, 4) = 4 + 5, # ceildiv(left_pad + j.extent, 8) = ceildiv(5 + 31, 8) = 5 + 4, # Range of iter%4 + 8, # Range of iter%8 + ], + padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tir.Or( + tvm.tir.Or( + tvm.tir.And(i_outer == 0, i_inner < 1), + tvm.tir.And(i_outer == 3, i_inner >= 3), + ), + tvm.tir.Or( + tvm.tir.And(j_outer == 0, j_inner < 5), + tvm.tir.And(j_outer == 4, j_inner >= 4), + ), + ), + ), + "multiple_right_padding": dict( + forward=lambda i: [i // 32, (i // 4) % 8, i % 4], + inverse=lambda i, j, k: [32 * i + 4 * j + k], + pre_shape=[116], + post_shape=[4, 8, 4], + padding=lambda i, j, k: tvm.tir.And(i == 3, 4 * j + k >= 20), + ), + "multiple_right_padding_transpose": dict( + forward=lambda i: [(i // 4) % 8, i // 32, i % 4], + inverse=lambda j, i, k: [32 * i + 4 * j + k], + pre_shape=[116], + post_shape=[8, 4, 4], + padding=lambda j, i, k: tvm.tir.And(i == 3, 4 * j + k >= 20), + ), + "multiple_left_padding": dict( + forward=lambda i: [(i + 5) // 32, ((i + 5) // 4) % 8, (i + 5) % 4], + inverse=lambda i, j, k: [32 * i + 4 * j + k - 5], + pre_shape=[123], + post_shape=[4, 8, 4], + padding=lambda i, j, k: tvm.tir.And(i == 0, j * 4 + k < 5), + ), + "multiple_left_padding_with_transpose": dict( + forward=lambda i: [((i + 5) // 4) % 8, (i + 5) // 32, (i + 5) % 4], + inverse=lambda j, i, k: [32 * i + 4 * j + k - 5], + pre_shape=[123], + post_shape=[8, 4, 4], + padding=lambda j, i, k: tvm.tir.And(i == 0, j * 4 + k < 5), + ), + } +) + + +def test_nonsurjective_inverse(padding_test_case): + index_map = IndexMap.from_func(padding_test_case["forward"]) + + inverse, padding_predicate = index_map.non_surjective_inverse(padding_test_case["pre_shape"]) + expected_inverse = IndexMap.from_func(padding_test_case["inverse"]) + assert inverse.is_equivalent_to(expected_inverse) + + post_shape = index_map.map_shape(padding_test_case["pre_shape"]) + tvm.ir.assert_structural_equal(post_shape, padding_test_case["post_shape"]) + + expected_predicate = padding_test_case["padding"](*inverse.initial_indices) + + # Can't use analyzer.can_prove_equal, because it can't simplify + # expressions like `(4*i+j >= 14) - (4*i+j >= 14)`. + analyzer = tvm.arith.Analyzer() + expected_predicate = analyzer.simplify(expected_predicate) + padding_predicate = analyzer.simplify(padding_predicate) + tvm.ir.assert_structural_equal(padding_predicate, expected_predicate) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) From 7d3d540d47791d83f3daadc77c65273679f108d4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 6 May 2022 14:08:31 -0500 Subject: [PATCH 08/10] Re-enabled divisibility check in CheckMapping Initially disabled as dynamic shapes resulted in padded lengths whose divisiblity couldn't be proven. Re-enabled along with a simplification rule to resolve it. --- src/arith/iter_affine_map.cc | 15 +++------------ src/arith/rewrite_simplify.cc | 4 ++++ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 829a32cddba7..f1a1b661c55e 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -578,18 +578,9 @@ class IterMapRewriter : public ExprMutator { return Array(); } } else { - // This still requires that the splits can be part of a - // bijective transformation, even if they are not sufficient to - // form a bijective transformation. - // - // TODO: Why is this condition used, and what is it intended to allow? - // - // if (!CanProveDivisible(mark->extent, expected_lower_factor)) { - // std::cout << "Expected resulting lower factor " << expected_lower_factor - // << " to be divisible by the extent of the parent mark " << mark->extent - // << std::endl; - // return Array(); - // } + if (!CanProveDivisible(mark->extent, expected_lower_factor)) { + return Array(); + } } return Array(iters.rbegin(), iters.rend()); } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 0f7aa4c8a978..4d8b6ff769cf 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -448,6 +448,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { TVM_TRY_REWRITE(min(x, y) * max(x, y), x * y); TVM_TRY_REWRITE(max(x, y) * min(x, y), x * y); + // Two representations of const*ceildiv(x, c1) + TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c2), c1) * c1, x - floormod(x, c2), + c1.Eval()->value == -c2.Eval()->value); + // canonicalization TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1); TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1); From c39cd4ea37ea6402c1994ee8832d8aa59edac47d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 10 May 2022 11:45:36 -0500 Subject: [PATCH 09/10] Fixed breakage in compute_at primitive --- src/arith/iter_affine_map.cc | 89 ++++++++++++++----------- src/tir/schedule/state.cc | 1 + tests/python/unittest/test_index_map.py | 2 +- 3 files changed, 52 insertions(+), 40 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index f1a1b661c55e..a012b6e80c08 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -28,6 +28,8 @@ #include #include +#include + #include "../support/utils.h" #include "const_fold.h" #include "pattern_match.h" @@ -334,7 +336,12 @@ class IterMapRewriter : public ExprMutator { // padding previously defined for the split using UpdatePadding can be // used. If no such previous padding exists, return an empty // IterMark. - IterSplitExpr PadDividendToDivisor(IterSplitExpr split, PrimExpr base, PrimExpr divisor); + // + // Returns a pair of IterSplit that represents (split+base) in a + // form that can be dividied by divisors, and PrimExpr that + // represents the left padding applied to split. + std::pair PadDividendToDivisor(IterSplitExpr split, PrimExpr base, + PrimExpr divisor); friend struct ErrorLogger; @@ -365,10 +372,7 @@ class IterMapRewriter : public ExprMutator { }; struct IterPaddingInfo { - IterPaddingInfo() : var_extent("extent") {} // Used and collected during first pass - IterSplitExpr unpadded; - Var var_extent; std::vector divisors; // Defined on first encounter in second pass @@ -420,9 +424,12 @@ class IterMapRewriter : public ExprMutator { // input iter marks std::vector input_marks_; - // Map from an IterSumExpr containing padding to the IterMark - // representing the padded version. - std::unordered_map padded_iter_map_; + // Map from a normal PrimExpr to the padded iterator information for + // it. This is necessary for introducing the same padding in all + // usage of an input iterator. (e.g. (i-1) occurring in the + // expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be + // left-padded by 31 for each occurrence.) + std::unordered_map padded_iter_map_; /* If allow_padding_ is true, allow the extents of the IterMap to be * padded beyond the original iterators. @@ -1253,12 +1260,13 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend) { } } -IterSplitExpr IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExpr base, - PrimExpr divisor) { +std::pair IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, + PrimExpr base, + PrimExpr divisor) { // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor // If FloorMod: (((source//lower_factor) % extent) + base) % divisor - IterSumExpr lookup_key({split}, 0); + PrimExpr lookup_key = split; auto modified_divisor = [&]() { if (update_iterator_padding_) { @@ -1299,7 +1307,7 @@ IterSplitExpr IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExp // Padding on the left is unnecessary if base is known to be zero. left_pad = make_zero(base->dtype); } else { - left_pad = floormod(base, divisor); + left_pad = analyzer_->Simplify(floormod(base, divisor)); } // Next, adding any padding that is on the upper side of a @@ -1314,11 +1322,11 @@ IterSplitExpr IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExp // the divisor. right_pad = 0; } else { - right_pad = floormod(-right_edge, divisor); + right_pad = analyzer_->Simplify(floormod(-right_edge, divisor)); } if (is_zero(left_pad) && is_zero(right_pad)) { - return split; + return {split, left_pad}; } if (update_iterator_padding_) { @@ -1327,22 +1335,15 @@ IterSplitExpr IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExp // to determine padding in the second pass. IterPaddingInfo& info = padded_iter_map_[lookup_key]; - info.unpadded = split; info.divisors.push_back(divisor); - // If an iterator contains padding, it may require a subsequent - // iterator to also require padding. Therefore, introduce a variable - // for the extent. + PrimExpr padded_extent = left_pad + split->extent + right_pad; - // PrimExpr padded_extent = left_pad + split->extent + right_pad; - // PrimExpr padded_extent = Var("N", divisor.dtype())*divisor; - PrimExpr padded_extent = info.var_extent; - - IterSumExpr as_sum({split}, base); + IterSumExpr as_sum({split}, left_pad); IterMark mark(as_sum, padded_extent); IterSplitExpr new_split(mark); - return new_split; + return {new_split, left_pad}; } // Any padding that is required during parsing should have been found @@ -1352,7 +1353,7 @@ IterSplitExpr IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExp ErrorLogger(this) << "Dividend has extent " << tvm::PrettyPrint(split->extent) << " and offset " << tvm::PrettyPrint(base) << ", which requires padding for divisor " << tvm::PrettyPrint(divisor) << "."; - return IterSplitExpr(); + return {IterSplitExpr(), left_pad}; } IterPaddingInfo& info = it->second; @@ -1362,12 +1363,12 @@ IterSplitExpr IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExp ICHECK(analyzer_->CanProveEqual(info.left_pad, left_pad)); ICHECK(analyzer_->CanProveEqual(info.right_pad, right_pad)); - return info.padded; + return {info.padded, left_pad}; } // This is the first encounter with the iterator during the second pass. - - IterMark mark(IterSumExpr({split}, base), left_pad + split->extent + right_pad); + IterSumExpr as_sum({split}, left_pad); + IterMark mark(as_sum, left_pad + split->extent + right_pad); info.padded = IterSplitExpr(mark); info.left_pad = left_pad; info.right_pad = right_pad; @@ -1376,8 +1377,8 @@ IterSplitExpr IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExp // Equivalent to (0 <= split < left_pad), but easier to simplify in // terms of the transformed variables. auto left_padding_predicate = - left_padding_introduced && - (floordiv(info.padded, divisor) == 0 && floormod(info.padded, divisor) < left_pad); + left_padding_introduced && (floordiv(info.padded, divisor) == floordiv(base, divisor) && + floormod(info.padded, divisor) < left_pad); PrimExpr nparts = ceildiv(right_edge, divisor); @@ -1385,14 +1386,14 @@ IterSplitExpr IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExp // Equivalent to (right_edge <= split < right_edge+right_pad), but // easier to simplify in terms of the transformed variables. - auto right_padding_predicate = - right_padding_introduced && (floordiv(info.padded, divisor) == nparts - 1 && - floormod(info.padded, divisor) >= floormod(right_edge, divisor)); + auto right_padding_predicate = right_padding_introduced && + (floordiv(info.padded, divisor) == floordiv(right_edge, divisor) && + floormod(info.padded, divisor) >= floormod(right_edge, divisor)); requires_padding_ = requires_padding_ || (left_padding_introduced || right_padding_introduced); padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); - return info.padded; + return {info.padded, left_pad}; } PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { @@ -1439,7 +1440,9 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P // We handle scale!=1 in above code, hence we only consider floordiv(x, rhs) below // where x=floormod(floordiv(iter, lower_factor), extent) + base - IterSplitExpr padded = PadDividendToDivisor(lhs, base, rhs); + auto pair = PadDividendToDivisor(lhs, base, rhs); + IterSplitExpr padded = pair.first; + PrimExpr left_pad = pair.second; if (!padded.defined()) { return PrimExpr(); } @@ -1451,10 +1454,17 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P // = floormod(sc2+t, c2) // = floormod(floordiv(y, c1), c2) // = floormod(floordiv(iter, lower_factor*c1), c2), where c1=rhs, c2=extent/rhs - return IterSplitExpr(padded->source, - /* lower_factor = */ padded->lower_factor * rhs, - /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), - /* scale = */ padded->scale); + IterSplitExpr new_split(padded->source, + /* lower_factor = */ padded->lower_factor * rhs, + /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), + /* scale = */ padded->scale); + + auto new_base = floordiv(base - left_pad, rhs); + if (is_zero(new_base)) { + return std::move(new_split); + } else { + return IterSumExpr({new_split}, new_base); + } } PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { @@ -1528,7 +1538,8 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P // We handle scale!=1 in above code, hence we only consider floormod(x, rhs) below // where x=floormod(floordiv(iter, lower_factor), extent) + base - IterSplitExpr padded = PadDividendToDivisor(lhs, base, rhs); + auto pair = PadDividendToDivisor(lhs, base, rhs); + IterSplitExpr padded = pair.first; if (!padded.defined()) { return PrimExpr(); } diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index eb43157d805a..3c11d2485332 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -109,6 +109,7 @@ bool ProducerCoversConsumer(const Array& buffer_shape, analyzer->canonical_simplify(consumed_region[i].max())); produced = arith::Intersect({produced, buffer_size}); consumed = arith::Intersect({consumed, buffer_size}); + if (!analyzer->CanProve((analyzer->canonical_simplify(produced.min() - consumed.min()) <= 0) && (analyzer->canonical_simplify(consumed.max() - produced.max()) <= 0))) { return false; diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_index_map.py index dff1bebbf2e6..a8f5204f0202 100644 --- a/tests/python/unittest/test_index_map.py +++ b/tests/python/unittest/test_index_map.py @@ -106,7 +106,7 @@ def test_nonbijective_inverse_gives_error(): post_shape=[(dynamic_N - 1) // 4 + 1, 4], padding=lambda i, j: tvm.tir.And( dynamic_N % (-4) != 0, - tvm.tir.And(i == (dynamic_N + 3) // 4 - 1, j >= dynamic_N % 4), + tvm.tir.And(i == dynamic_N // 4, j >= dynamic_N % 4), ), ), "2d_padding": dict( From 7708d9eee37b4067d41dc01f0c2309483f200aad Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 11 May 2022 09:25:55 -0500 Subject: [PATCH 10/10] Corrected typos/examples in docstring --- python/tvm/tir/function.py | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 7029d6f9f809..d84513e072d3 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -362,8 +362,8 @@ def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap": Throws an error if the function is not bijective. - Paramters - --------- + Parameters + ---------- shape: List[Union[Range,PrimExpr]] The region over which the inverse should be determined. @@ -387,23 +387,8 @@ def non_surjective_inverse( Can be applied to transformations that introduce padding. - Examples - -------- - - Before unroll, in TensorIR, the IR is: - - .. code-block:: python - - index_map = IndexMap.from_func(lambda i: [i//4, i%4]) - inverse_map, predicate = index_map.non_surjective_inverse([14]) - inverse_map - - .. code-block:: python - - index_map = IndexMap.from_func(lambda i: [i//4, i%4]) - - Paramters - --------- + Parameters + ---------- shape: List[Union[Range,PrimExpr]] The region over which the inverse should be determined. @@ -415,6 +400,16 @@ def non_surjective_inverse( The inverse, and a predicate for which the inverse maps to a valid index in the input range. + + Examples + -------- + + .. code-block:: python + + index_map = IndexMap.from_func(lambda i: [i//4, i%4]) + inverse_map, predicate = index_map.non_surjective_inverse([14]) + assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j,k: [4*j + k]) + print(predicate) # Prints "(axis0==3) && (axis2 >= 2)" """ shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape]