diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc
index 82102749609b..45d0c81050d1 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -680,6 +680,20 @@ void CalculateProvidedRequiredRegions(
 
 /******** Main Implementation ********/
 
+void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
+                       arith::Analyzer* analyzer) {
+  while (sref->parent != nullptr) {
+    sref = sref->parent;
+  }
+  const PrimFuncNode* f = GetRootPrimFunc(state->mod, sref->stmt, nullptr);
+  for (const auto& kv : f->buffer_map) {
+    const Buffer& buffer = kv.second;
+    for (const PrimExpr& e : buffer->shape) {
+      analyzer->MarkGlobalNonNegValue(e);
+    }
+  }
+}
+
 template <bool is_compute_at>
 void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
                                      const StmtSRef& loop_sref, bool preserve_unit_loops,
@@ -692,6 +706,7 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
   StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
   Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
+  AddShapeVarBounds(self, scope_root_sref.get(), analyzer);
   BlockScope scope = self->GetBlockScope(scope_root_sref);
   Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
   Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py
index a1b4cf155949..2e44776a0fdc 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1823,5 +1823,74 @@ def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32"))
     verify_trace_roundtrip(sch=sch, mod=before)
 
 
+def test_shape_var_as_bound():
+    # fmt: off
+    @T.prim_func
+    def before(a: T.handle, b: T.handle, c: T.handle):
+        n = T.int32()
+        A = T.match_buffer(a, (32, 1, 128))
+        B = T.match_buffer(b, (32, n, 128))
+        C = T.match_buffer(c, (32, 1, n))
+        # with T.block("root"):
+        C_rf = T.alloc_buffer((128, 32, 1, n))
+        for ax0_ax1_fused, ax2_fused_1, ax2_fused_0 in T.grid(n * 32, 128, 1):
+            with T.block("NT_matmul_rf"):
+                vax2_fused_1 = T.axis.spatial(128, ax2_fused_1)
+                v0 = T.axis.spatial(32, ax0_ax1_fused // n)
+                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                vax2_fused_0 = T.axis.reduce(1, ax2_fused_0)
+                T.reads(A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1], B[v0, v1, vax2_fused_0 * 128 + vax2_fused_1])
+                T.writes(C_rf[vax2_fused_1, v0, 0, v1])
+                with T.init():
+                    C_rf[vax2_fused_1, v0, 0, v1] = T.float32(0)
+                C_rf[vax2_fused_1, v0, 0, v1] = C_rf[vax2_fused_1, v0, 0, v1] + A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1] * B[v0, v1, vax2_fused_0 * 128 + vax2_fused_1]
+        for ax0_ax1_fused, ax2_fused_1 in T.grid(n * 32, 128):
+            with T.block("NT_matmul"):
+                vax2_fused_1 = T.axis.reduce(128, ax2_fused_1)
+                v0 = T.axis.spatial(32, ax0_ax1_fused // n)
+                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                T.reads(C_rf[vax2_fused_1, v0, 0, v1])
+                T.writes(C[v0, 0, v1])
+                with T.init():
+                    C[v0, 0, v1] = T.float32(0)
+                C[v0, 0, v1] = C[v0, 0, v1] + C_rf[vax2_fused_1, v0, 0, v1]
+
+    @T.prim_func
+    def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle):
+        n = T.int32()
+        B = T.match_buffer(b, (32, n, 128))
+        C = T.match_buffer(c, (32, 1, n))
+        # with T.block("root"):
+        C_rf = T.alloc_buffer((128, 32, 1, n))
+        for ax0_ax1_fused in range(n * 32):
+            for ax2_fused_1, ax2_fused_0 in T.grid(128, 1):
+                with T.block("NT_matmul_rf"):
+                    vax2_fused_1 = T.axis.spatial(128, ax2_fused_1)
+                    v0 = T.axis.spatial(32, ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    vax2_fused_0 = T.axis.reduce(1, ax2_fused_0)
+                    T.reads(A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1], B[v0, v1, vax2_fused_0 * 128 + vax2_fused_1])
+                    T.writes(C_rf[vax2_fused_1, v0, 0, v1])
+                    with T.init():
+                        C_rf[vax2_fused_1, v0, 0, v1] = T.float32(0)
+                    C_rf[vax2_fused_1, v0, 0, v1] = C_rf[vax2_fused_1, v0, 0, v1] + A[v0, 0, vax2_fused_0 * 128 + vax2_fused_1] * B[v0, v1, vax2_fused_0 * 128 + vax2_fused_1]
+            for ax0, ax1, ax2 in T.grid(128, 1, 1):
+                with T.block("NT_matmul"):
+                    vax2_fused_1 = T.axis.reduce(128, ax0)
+                    v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax1)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax2)
+                    T.reads(C_rf[vax2_fused_1, v0, 0, v1])
+                    T.writes(C[v0, 0, v1])
+                    with T.init():
+                        C[v0, 0, v1] = T.float32(0)
+                    C[v0, 0, v1] = C[v0, 0, v1] + C_rf[vax2_fused_1, v0, 0, v1]
+    # fmt: on
+    sch = tir.Schedule(before, debug_mask="all")
+    block = sch.get_block("NT_matmul")
+    loop, _, _ = sch.get_loops(sch.get_block("NT_matmul_rf"))
+    sch.reverse_compute_at(block, loop, preserve_unit_loops=True)
+    tvm.ir.assert_structural_equal(sch.mod["main"], expected, True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
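Reviewer note (illustration, not part of the patch): the fix works because `MarkGlobalNonNegValue` tells the scheduler's `arith::Analyzer` that every buffer-shape expression, including the free shape var `n`, is non-negative; without that fact, bound proofs involving `n` (e.g. for the `ax0_ax1_fused // n` and `ax0_ax1_fused % n` block bindings) cannot go through and `reverse_compute_at` fails. The sketch below mimics the effect through the public Python `tvm.arith.Analyzer` API, using a `ConstIntBound` update as a stand-in for the C++ `MarkGlobalNonNegValue` call; it illustrates the idea only and is not the code path the primitive takes.

```python
# Minimal sketch (not part of the patch): why bounding shape vars matters.
# ConstIntBound(0, +inf) stands in for the C++ MarkGlobalNonNegValue call.
import tvm
from tvm import tir

analyzer = tvm.arith.Analyzer()
n = tir.Var("n", "int32")

# With no information about n, the sign of n * 32 is unknown.
assert not analyzer.can_prove(n * 32 >= 0)

# Constrain n to [0, +inf), mirroring what AddShapeVarBounds does for
# every dimension of every buffer in the PrimFunc's buffer_map.
analyzer.update(n, tvm.arith.ConstIntBound(0, tvm.arith.ConstIntBound.POS_INF))
assert analyzer.can_prove(n * 32 >= 0)
```

Because `AddShapeVarBounds` walks up to the root sref and reads the bounds off the root `PrimFunc`'s `buffer_map`, the facts hold for the whole scope, which is why the call is made once in `ComputeAtOrReverseComputeAtImpl` before the producer/consumer analysis.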