diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index 67b9ffffe21f..9f45317cba11 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -567,6 +567,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
         p->stream << ", ";
         p->Print(s);
       }
+      p->stream << ')';
     });
 
 // Sub-class RewriteSimplifier::Impl to take benefit of
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 732045384a95..ccdb952d2d42 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -84,6 +84,9 @@ RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::TryCompare(const
     }
   }
   ConstIntBound dbound = analyzer_->const_int_bound(diff);
+  if (dbound->min_value == val && dbound->max_value == val) {
+    return kEQ;
+  }
   if (dbound->min_value > val) {
     return kGT;
   }
@@ -819,6 +822,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
     // Rules involving 3-operands.
     TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
                        c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
+    TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), floordiv(x, floordiv(c2, c1)),
+                       c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
+                           c2.Eval()->value % c1.Eval()->value == 0 &&
+                           CanProveEqual(floordiv(y.Eval() + z.Eval(), c1.Eval()), 0));
 
     TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), x * floordiv(c1, c2) + floordiv(z - y, c2),
                        c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
@@ -916,6 +923,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
     TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(y, c2),
                        c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
 
+    // Sound only when y is provably in [0, c1): an upper bound alone would admit
+    // negative y, e.g. floormod(x * 32 - 1, 64) at x = 0 is 63, not -1.
+    TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y,
+                       c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
+                           c2.Eval()->value % c1.Eval()->value == 0 &&
+                           CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));
+
TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index b1919f6eeb94..e07bdba02046 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -504,6 +504,11 @@ def test_floordiv_index_simplify(): ck.verify(fld(y + x * z, z), fld(y, z) + x) ck.verify(fld(y + z * x, z), fld(y, z) + x) + ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 31), override=True) + ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 3), override=True) + ck.verify(fld(x * 32 + y, 64), fld(x, 2)) + ck.verify(fld(x * 128 + y * 4 + z, 512), fld(x, 4)) + def test_mod_index_simplify(): ck = RewriteChecker() @@ -559,6 +564,9 @@ def test_floormod_index_simplify(): ck.verify(flm(x + (-10), 2), flm(x, 2)) ck.verify(flm(x + y * (-10), 2), flm(x, 2)) + ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 31), override=True) + ck.verify(flm(x * 32 + y, 64), flm(x, 2) * 32 + y) + def test_min_index_simplify(): ck = RewriteChecker() diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py index 7b6ef5256ae9..db0446b08044 100644 --- a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py +++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py @@ -315,7 +315,7 @@ def _create_schedule(): 25.0, 16.000022888183594, 15.000043869018555, - 10.001408576965332, + 10.001408194392809, 0.0, ], rtol=1e-5, @@ -951,8 +951,8 @@ def _create_schedule(): 0.0, 0.0, 0.0, - 22.00000034396526, - 22.00000034396526, + 21.584962959341485, + 21.584962959341485, 21.000000687930438, 0.0, 0.0, @@ -1032,7 +1032,7 @@ def _create_schedule(): 0.0, 0.0, 3.169925001442312, - 10.001408194392809, + 9.61654884377899, 
8.005624549193879, 14.000088052430122, 1.584962500721156, diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py index 7f60c95164a8..fb1fb72eb82c 100644 --- a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -89,12 +89,12 @@ class After_simplified: def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) - T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) - T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") + T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data) + T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data) + T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data) # body T.launch_thread(blockIdx_x, 64) conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") @@ -107,7 +107,7 @@ def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "flo for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32") for 
ax0_ax1_ax2_ax3_fused_0 in T.serial(32): - weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] + weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)] for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024] for ax1, ax2 in T.grid(2, 4):