Skip to content

Commit 5fba48d

Browse files
committed
[TIR] Fix reverse_compute_at for trivial region with trivial block var
1 parent eae836c commit 5fba48d

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

src/arith/interval_set.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,13 @@ class IntervalSetNode : public IntSetNode {
5959
/*! \return Whether the interval has lower bound. */
6060
bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); }
6161
/*! \return Whether the interval is a single point. */
62-
bool IsSinglePoint() const { return min_value.same_as(max_value); }
62+
bool IsSinglePoint() const {
63+
if (min_value.same_as(max_value)) {
64+
return true;
65+
}
66+
Analyzer analyzer;
67+
return analyzer.CanProveEqual(min_value, max_value);
68+
}
6369
/*! \return whether interval represent nothing */
6470
bool IsEmpty() const {
6571
// during computations, either extreme could occur.

src/tir/schedule/primitive/compute_at.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,12 @@ void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& re
428428
const arith::IntSet& required_bound,
429429
std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms,
430430
arith::Analyzer* analyzer) {
431+
if (provided.IsSinglePoint() && is_const_int(provided.min())) {
432+
ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min()));
433+
ICHECK(required_bound.IsSinglePoint() &&
434+
analyzer->CanProveEqual(provided.min(), required_bound.min()));
435+
return;
436+
}
431437
auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer);
432438
auto var_with_bound = SolveBlockVarDomain(provided, required_bound, analyzer);
433439
const Var& var = var_with_dom.first;

tests/python/unittest/test_tir_schedule_compute_at.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,40 @@ def tiled_after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
582582
C[vi, vj] = B[vi, vj] + 1.0
583583

584584

585+
@T.prim_func
586+
def tiled_trivial_binding(a: T.handle, c: T.handle) -> None:
587+
A = T.match_buffer(a, [1, 128, 128], "float32")
588+
B = T.alloc_buffer([1, 128, 128], "float32")
589+
C = T.match_buffer(c, [1, 128, 128], "float32")
590+
for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
591+
with T.block("B"):
592+
vi = T.axis.S(128, i_0 * 16 + i_1)
593+
vj = T.axis.S(128, j_0 * 16 + j_1)
594+
B[0, vi, vj] = A[0, vi, vj] * 2.0
595+
for i, j in T.grid(128, 128):
596+
with T.block("C"):
597+
vi, vj = T.axis.remap("SS", [i, j])
598+
C[0, vi, vj] = B[0, vi, vj] + 1.0
599+
600+
601+
@T.prim_func
602+
def tiled_trivial_binding_after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
603+
A = T.match_buffer(a, [1, 128, 128], "float32")
604+
B = T.alloc_buffer([1, 128, 128], "float32")
605+
C = T.match_buffer(c, [1, 128, 128], "float32")
606+
for i_0, j_0, i_1 in T.grid(8, 8, 16):
607+
for j_1 in T.serial(0, 16):
608+
with T.block("B"):
609+
vi = T.axis.S(128, i_0 * 16 + i_1)
610+
vj = T.axis.S(128, j_0 * 16 + j_1)
611+
B[0, vi, vj] = A[0, vi, vj] * 2.0
612+
for j_1 in T.serial(0, 16):
613+
with T.block("C"):
614+
vi = T.axis.S(128, i_0 * 16 + i_1)
615+
vj = T.axis.S(128, j_0 * 16 + j_1)
616+
C[0, vi, vj] = B[0, vi, vj] + 1.0
617+
618+
585619
@T.prim_func
586620
def factorized(a: T.handle, b: T.handle) -> None:
587621
A = T.match_buffer(a, [16, 16, 16], "float32")
@@ -1149,6 +1183,15 @@ def test_reverse_compute_at_tiled():
11491183
verify_trace_roundtrip(sch=sch, mod=tiled)
11501184

11511185

1186+
def test_reverse_compute_at_tiled_trivial_binding():
1187+
sch = tir.Schedule(tiled_trivial_binding, debug_mask="all")
1188+
block = sch.get_block("C")
1189+
_, _, loop, _ = sch.get_loops(sch.get_block("B"))
1190+
sch.reverse_compute_at(block, loop, preserve_unit_loops=False)
1191+
tvm.ir.assert_structural_equal(tiled_trivial_binding_after_reverse_compute_at, sch.mod["main"])
1192+
verify_trace_roundtrip(sch=sch, mod=tiled_trivial_binding)
1193+
1194+
11521195
def test_reverse_compute_at_blockized_2():
11531196
sch = tir.Schedule(blockized_2, debug_mask="all")
11541197
block = sch.get_block("C")

0 commit comments

Comments
 (0)