From ae8f00ccb94f346e109f12f6fa13a3402d00d420 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 20 Jun 2023 19:20:27 -0400 Subject: [PATCH] [ARITH] Hotfix flaky test in padded matmul This PR hotfixes a flaky test in padded matmul --- src/arith/iter_affine_map.cc | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 377f8bb7c9b1..cf5281a6cff0 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -658,7 +658,7 @@ class IterMapRewriter : public ExprMutator { if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max.value() - base; } - Optional opt = TryFuseIters(expr, check_level_); + Optional opt = TryFuseIters(expr, check_level_, false); ICHECK(!opt.defined() || opt.value()->args.size() == 1); // scale should be 1 if (opt.defined() && is_one(opt.value()->args[0]->scale)) { @@ -722,11 +722,7 @@ class IterMapRewriter : public ExprMutator { IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { // We are normalizing a regular iter if (expr->args.size() < 1) return expr; - if (auto opt = TryCombineSplitFromSameSource(expr)) { - expr = opt.value(); - if (expr->args.size() < 1) return expr; - } - Optional opt = TryFuseIters(expr, check_level_); + Optional opt = TryFuseIters(expr, check_level_, true); if (opt.defined()) { return opt.value(); } else { @@ -996,9 +992,18 @@ class IterMapRewriter : public ExprMutator { * Try to normalize IterSum into a fused IterMark * \param expr The input sum. * \param check_level The check level if iter mapping. - * \return The sum with the fused IterMark and extra offset if succeed. + * \param allow_early_skip Whether do we allow early skip if expr is simple + * (this may cause us to return parameters that are not canonically wrapped as + * IterSum(IterMark)) \return The sum with the fused IterMark and extra offset if succeed. */ - Optional TryFuseIters(IterSumExpr expr, IterMapLevel check_level) { + Optional TryFuseIters(IterSumExpr expr, IterMapLevel check_level, + bool allow_early_skip) { + if (auto opt = TryCombineSplitFromSameSource(expr)) { + expr = opt.value(); + if (expr->args.size() <= 1 && allow_early_skip) { + return expr; + } + } // select the iterators in order std::vector visited(expr->args.size(), false); int base_index = FindBaseIter(expr, visited, NullOpt); @@ -1554,12 +1559,8 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o return IterSumExpr(); } else if (sum->args.size() == 1) { return sum; - } else if (auto opt = TryCombineSplitFromSameSource(sum)) { - if (opt.value()->args.size() == 1) { - return opt.value(); - } } - auto opt_fused = TryFuseIters(sum, check_level_); + auto opt_fused = TryFuseIters(sum, check_level_, true); if (!opt_fused) { ErrorLogger(this) << "Dividend " << original_dividend << ", can't be written as a single fused IterSum";