diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index af1128aa273c..89a803d058e4 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -486,7 +486,7 @@ class IterMapRewriter : public ExprMutator { bool requires_padding_{false}; // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly - // an extra offset) + // an extra offset). The normal form always has minimum value of zero. // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate: j*2 + k < 9 // Then, flattened form = IterSum(IterSplit(i, scale=9), @@ -497,6 +497,7 @@ class IterMapRewriter : public ExprMutator { // IterSplit(k, scale=1)), // extent=9) // scale=1)) + // offset = 0 // Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate: 1 <= j*2 + k < 9 // Then, flattened form = IterSum(IterSplit(i, scale=8), @@ -507,9 +508,15 @@ class IterMapRewriter : public ExprMutator { // IterSplit(k, scale=1), base=-1), // extent=9-1) // scale=1), - // base=1) + // base=0) + // offset = 1 std::unordered_map sum_fuse_map_; // The map for sum that maps normal form to flattened form + // For sum_fuse_map_ and flattened_map_ the following invariants hold: + // for any IterSumExpr e in the flattened_form, we have + // iter_mark, mark_offset = sum_fuse_map_[e] + // flattened_map_[normal_form] = e where normal_form = iter_mark->args[0] and + // iter_mark->args.size() = 1 std::unordered_map flattened_map_; // The flattened forms of constrained iters std::vector constrained_iters_flattened_; @@ -685,7 +692,10 @@ class IterMapRewriter : public ExprMutator { PrimExpr mark_offset = it_mark->second.offset; PrimExpr iter_min = mark_offset; PrimExpr iter_max = iter_min + mark->extent; + // predicate-induced lower bound minus iter_min; stays zero without a lower-bound predicate + PrimExpr iter_min_delta = make_const(iter_min.dtype(), 0); if (predicate_induced_min.defined()) { + iter_min_delta = 
predicate_induced_min.value() - iter_min; iter_min = max(predicate_induced_min.value(), iter_min); } if (predicate_induced_max.defined()) { @@ -704,10 +714,12 @@ class IterMapRewriter : public ExprMutator { iter_max = min(predicate_induced_max.value(), iter_max); } } - if (!is_zero(iter_min)) { + // When iter_min_delta is present, we need to normalize the structured form to have minimum of + // 0, and add the delta to the mark_offset + if (!is_zero(iter_min_delta)) { // structured form's offset should be updated flattened_map_.erase(structured_form); - structured_form.CopyOnWrite()->base = -iter_min; + structured_form.CopyOnWrite()->base -= iter_min_delta; mark.CopyOnWrite()->source = structured_form; flattened_map_[structured_form] = flattened_form; } @@ -716,8 +728,9 @@ class IterMapRewriter : public ExprMutator { // we need to note down the flattened form of constrained iterators // to check the validity of constraints, see also CheckConstraints() constrained_iters_flattened_.push_back(flattened_form); - expr.CopyOnWrite()->args = Array({split}); - expr.CopyOnWrite()->base = base + iter_min; + IterSumExprNode* normalized_expr = expr.CopyOnWrite(); + normalized_expr->args = Array({split}); + normalized_expr->base = base; return expr; } ErrorLogger(this) << "Could not normalize iterators using the constraints given."; @@ -1089,8 +1102,8 @@ class IterMapRewriter : public ExprMutator { std::vector flattened_iters, grouped_iters; // check if it can be remapped into a fused pattern. 
- PrimExpr expected_extra_base = 0; - PrimExpr tail_extent = 0; + PrimExpr expected_extra_base = make_const(expr.dtype(), 0); + PrimExpr tail_extent = make_const(expr.dtype(), 0); PrimExpr expected_scale = base_scale; int first_possible_unit_extent_pos = FindFirstPossibleUnitExtentIndex(expr); @@ -1143,8 +1156,9 @@ class IterMapRewriter : public ExprMutator { size_t k = 0; for (; k < expr->args.size(); ++k) { if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) { - if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale)) + if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale)) { break; + } } } if (k == expr->args.size()) { @@ -1201,7 +1215,7 @@ class IterMapRewriter : public ExprMutator { } else { // new iter, form a new mark IterMark mark = IterMark(structured_form, div(expected_scale, base_scale) + tail_extent); - sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0); + sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, expected_extra_base); flattened_map_[structured_form] = flattened_form; return IterSumExpr({IterSplitExpr(mark, base_scale)}, expr->base + expected_extra_base); } diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index cee9922e86fa..63bb79d2b223 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -236,6 +236,7 @@ def test_compound_floormod_two_regression(): def test_predicate(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") # available contraints # upper bound only @@ -269,6 +270,12 @@ def test_predicate(): predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127), ) + assert_iter_sum_pattern( + {x * 64 + y * 4 + z: (16, 16)}, + var_dom([(x, 16), (y, 16), (z, 4)]), + predicate=tvm.tir.And(x * 64 + y * 4 + z < 32, 4 <= x * 16 + y), + ) + # constraint on one fused iter i = tvm.tir.Var("i", 
"int32") j = tvm.tir.Var("j", "int32")