From b7a004d9d68c075a596029ed6ec0c92e3c787bd3 Mon Sep 17 00:00:00 2001 From: liujiaqiang Date: Tue, 10 Sep 2024 15:57:24 +0800 Subject: [PATCH] [FIX] fix bug when normalizing iter with different lower bounds If an iter has already been normalized with a lower bound and is then normalized again with a new lower bound, iter_min needs to be updated only when the new lower bound is greater than the current one (iter_min = max(new lower bound, iter_min)); otherwise iter_min_delta must stay 0 instead of becoming negative. --- src/arith/iter_affine_map.cc | 2 +- .../arith/test_arith_iter_affine_map.py | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 77b20fcdf203..d24c278f1048 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -696,7 +696,7 @@ class IterMapRewriter : public ExprMutator { // the delta of iter_min when it is updated when the lower bound predicate is present PrimExpr iter_min_delta = make_const(iter_min.dtype(), 0); if (predicate_induced_min.defined()) { - iter_min_delta = predicate_induced_min.value() - iter_min; + iter_min_delta = max(predicate_induced_min.value(), iter_min) - iter_min; iter_min = max(predicate_induced_min.value(), iter_min); } if (predicate_induced_max.defined()) { diff --git a/tests/python/arith/test_arith_iter_affine_map.py b/tests/python/arith/test_arith_iter_affine_map.py index f0e6f05adfad..f34dce5c86fd 100644 --- a/tests/python/arith/test_arith_iter_affine_map.py +++ b/tests/python/arith/test_arith_iter_affine_map.py @@ -346,6 +346,27 @@ def test_predicate(): predicate=tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), ) + # constraint with different lower bound + assert_iter_sum_pattern( + { + (i * 16 + j) // 23 * 8 + + (i * 16 + j) % 23 + - 15: ( + 64, + 0, + 1, + (i * 16 + j) // 23 * 8 + ((i * 16 + j) % 23 + tvm.tir.IntImm("int32", -15)), + ) + }, + var_dom([(i, 12), (j, 16)]), + predicate=tvm.tir.And( + tvm.tir.And( + i * 16 + j < 184, tvm.tir.LE(tvm.tir.IntImm("int32", 8), (i * 16 + j) % 23) + ), + 
tvm.tir.LE(tvm.tir.IntImm("int32", 15), (i * 16 + j) % 23), + ), + ) + # constraint on many disjoint fused iters, case 1 # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1)