Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ", ";
p->Print(s);
}
p->stream << ')';
});

// Sub-class RewriteSimplifier::Impl to take benefit of
Expand Down
12 changes: 12 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::TryCompare(const
}
}
ConstIntBound dbound = analyzer_->const_int_bound(diff);
if (dbound->min_value == val && dbound->max_value == val) {
return kEQ;
}
if (dbound->min_value > val) {
return kGT;
}
Expand Down Expand Up @@ -819,6 +822,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
// Rules involving 3-operands.
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), floordiv(x, floordiv(c2, c1)),
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
c2.Eval()->value % c1.Eval()->value == 0 &&
CanProveEqual(floordiv(y.Eval() + z.Eval(), c1.Eval()), 0));

TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), x * floordiv(c1, c2) + floordiv(z - y, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
Expand Down Expand Up @@ -916,6 +923,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(y, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y,
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
c2.Eval()->value % c1.Eval()->value == 0 &&
analyzer_->CanProveLess(y.Eval(), c1.Eval()->value));

TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,11 @@ def test_floordiv_index_simplify():
ck.verify(fld(y + x * z, z), fld(y, z) + x)
ck.verify(fld(y + z * x, z), fld(y, z) + x)

ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 31), override=True)
ck.analyzer.update(z, tvm.arith.ConstIntBound(0, 3), override=True)
ck.verify(fld(x * 32 + y, 64), fld(x, 2))
ck.verify(fld(x * 128 + y * 4 + z, 512), fld(x, 4))


def test_mod_index_simplify():
ck = RewriteChecker()
Expand Down Expand Up @@ -559,6 +564,9 @@ def test_floormod_index_simplify():
ck.verify(flm(x + (-10), 2), flm(x, 2))
ck.verify(flm(x + y * (-10), 2), flm(x, 2))

ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 31), override=True)
ck.verify(flm(x * 32 + y, 64), flm(x, 2) * 32 + y)


def test_min_index_simplify():
ck = RewriteChecker()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def _create_schedule():
25.0,
16.000022888183594,
15.000043869018555,
10.001408576965332,
10.001408194392809,
0.0,
],
rtol=1e-5,
Expand Down Expand Up @@ -951,8 +951,8 @@ def _create_schedule():
0.0,
0.0,
0.0,
22.00000034396526,
22.00000034396526,
21.584962959341485,
21.584962959341485,
21.000000687930438,
0.0,
0.0,
Expand Down Expand Up @@ -1032,7 +1032,7 @@ def _create_schedule():
0.0,
0.0,
3.169925001442312,
10.001408194392809,
9.61654884377899,
8.005624549193879,
14.000088052430122,
1.584962500721156,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ class After_simplified:
def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data)
T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data)
T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data)
# var definition
threadIdx_x = T.env_thread("threadIdx.x")
blockIdx_x = T.env_thread("blockIdx.x")
T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data)
T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data)
T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data)
# body
T.launch_thread(blockIdx_x, 64)
conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
Expand All @@ -107,7 +107,7 @@ def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "flo
for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32")
for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024]
for ax1, ax2 in T.grid(2, 4):
Expand Down