diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 4174a6699e06..a2b45d407ddf 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -417,7 +417,13 @@ class CacheLocDetector : public StmtVisitor {
       info->loc_pos = detector.loc_pos_;
     } else {
       info->loc_sref = scope_sref;
-      const auto* body = scope_sref->StmtAs<BlockNode>()->body.as<SeqStmtNode>();
+
+      auto block_body = scope_sref->StmtAs<BlockNode>()->body;
+      // Find the SeqStmtNode within (potentially nested) AllocateConstNodes
+      while (block_body->IsInstance<AllocateConstNode>()) {
+        block_body = block_body.as<AllocateConstNode>()->body;
+      }
+      const auto* body = block_body.as<SeqStmtNode>();
       info->loc_pos = body == nullptr ? 1 : body->size();
     }
   }
diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py
index 28c9a13700bf..6a75057e72ff 100644
--- a/tests/python/unittest/test_tir_schedule_cache_read_write.py
+++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py
@@ -1005,6 +1005,67 @@ def block_predicate_cache_write_output_buf() -> None:
 use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
 
 
+@T.prim_func
+def cache_write_allocate_const(
+    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float16"]
+):
+    B = T.alloc_buffer([128, 128], dtype="float32")
+    const = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+    const_1 = T.buffer_decl([8], dtype="float32", data=const)
+    const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+    const_2 = T.buffer_decl([8], dtype="float32", data=const)
+    for i, j in T.grid(128, 128):
+        for x in range(8):
+            with T.block("B"):
+                vi, vj, vx = T.axis.remap("SSS", [i, j, x])
+                T.reads(A[vi, vj], const_1[vx], const_2[vx])
+                T.writes(B[vi, vj])
+                B[vi, vj] = A[vi, vj] * const_1[vx] + const_2[vx]
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(B[vi, vj])
+            T.writes(C[vi, vj])
+            C[vi, vj] = B[vi, vj] + 1.0
+
+
+@T.prim_func
+def cache_write_allocate_const_output(
+    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float16"]
+):
+    B = T.alloc_buffer([128, 128], dtype="float32")
+    A_global = T.alloc_buffer([128, 128], dtype="float32")
+    C_global = T.alloc_buffer([128, 128], dtype="float16")
+    const_2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+    const_1 = T.buffer_decl([8], dtype="float32", data=const_2)
+    const_2_1 = T.buffer_decl([8], dtype="float32", data=const_2)
+    const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+    for ax0, ax1 in T.grid(128, 128):
+        with T.block("A_global"):
+            v0, v1 = T.axis.remap("SS", [ax0, ax1])
+            T.reads(A[v0, v1])
+            T.writes(A_global[v0, v1])
+            A_global[v0, v1] = A[v0, v1]
+    for i, j, x in T.grid(128, 128, 8):
+        with T.block("B"):
+            vi, vj, vx = T.axis.remap("SSS", [i, j, x])
+            T.reads(A_global[vi, vj], const_1[vx], const_2_1[vx])
+            T.writes(B[vi, vj])
+            B[vi, vj] = A_global[vi, vj] * const_1[vx] + const_2_1[vx]
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(B[vi, vj])
+            T.writes(C_global[vi, vj])
+            C_global[vi, vj] = B[vi, vj] + T.float32(1)
+    for ax0, ax1 in T.grid(128, 128):
+        with T.block("C_global"):
+            v0, v1 = T.axis.remap("SS", [ax0, ax1])
+            T.reads(C_global[v0, v1])
+            T.writes(C[v0, v1])
+            C[v0, v1] = C_global[v0, v1]
+
+
 def test_cache_read_elementwise(use_block_name):
     sch = tir.Schedule(elementwise, debug_mask="all")
     block_b = sch.get_block("B")
@@ -1265,5 +1326,15 @@ def test_cache_write_fail_invalid_storage_scope(use_block_name):
         sch.cache_write(block_b, 0, "test_scope")
 
 
+def test_cache_write_allocate_const():
+    sch = tir.Schedule(cache_write_allocate_const)
+    block_b = sch.get_block("B")
+    block_c = sch.get_block("C")
+    sch.cache_read(block_b, 0, "global")
+    sch.cache_write(block_c, 0, "global")
+    tvm.ir.assert_structural_equal(cache_write_allocate_const_output, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=cache_write_allocate_const)
+
+
 if __name__ == "__main__":
     tvm.testing.main()