From 6a548cd3549b93bdfbf790b5ab4f163dddb8c710 Mon Sep 17 00:00:00 2001 From: Min Chen Date: Thu, 10 Nov 2022 09:34:14 +0000 Subject: [PATCH 1/3] [TIR][Schedule] Fix region_cover checking for cache related primitives --- .../schedule/primitive/cache_read_write.cc | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 2c86c2df2d25..d4c5e74e0e4b 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -1078,7 +1078,7 @@ class ReIndexRewriter : public StmtExprMutator { Region region_; }; -void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root) { +void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer read_buffer) { class NotRegionCoverError : public ScheduleError { public: explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {} @@ -1095,12 +1095,16 @@ The region cover property require to hold for every of its child blocks IRModule mod_; Block block_; }; - BlockScope scope = self->GetBlockScope(scope_root); - for (const auto& kv : scope->dst2deps) { - const StmtSRef& consumer_block_sref = kv.first; - if (!self->block_info.at(consumer_block_sref).region_cover) { - const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root); - throw NotRegionCoverError(self->mod, GetRef(block)); + + for (const auto& child_block_sref : tir::GetChildBlocks(self, scope_root)) { + const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref); + for (const BufferRegion& region : child_block->reads) { + if (region->buffer.same_as(read_buffer)) { + if (!self->block_info.at(child_block_sref).region_cover) { + const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root); + throw NotRegionCoverError(self->mod, GetRef(block)); + } + } } } } @@ -1129,7 +1133,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Check required region cover for cache_read - CheckRegionCover(self, scope_sref); + CheckRegionCover(self, scope_sref, read_buffer); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); // Step 2. Create CacheStageInfo @@ -1281,7 +1285,7 @@ Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Check 3. Check required region cover for cache_read - CheckRegionCover(self, scope_sref); + CheckRegionCover(self, scope_sref, buffer); // Check 4. Check if target block both read & write target buffer. const BlockNode* rw_block = TVM_SREF_TO_BLOCK(block_sref); @@ -1318,6 +1322,8 @@ Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); BlockInfo& block_info_read = self->block_info[result_block_sref]; block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info_read.region_cover = true; + block_info_read.scope->stage_pipeline = false; results_block_sref.push_back(result_block_sref); // Do cache write From 7be4e101bb51af5a7f433486bdd06e4b7905788e Mon Sep 17 00:00:00 2001 From: Min Chen Date: Thu, 10 Nov 2022 09:38:09 +0000 Subject: [PATCH 2/3] [TIR][Schedule] Fix CacheLocDetector for nested SeqStmt --- src/tir/schedule/primitive/cache_read_write.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index d4c5e74e0e4b..b3e0e8f1274e 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -437,13 +437,13 @@ class CacheLocDetector : public StmtVisitor { if (visited_block_ && visited_related_ && loc_pos_ == -1) { // The offset of insert position from the block loc_pos_ = i; - return; + break; } else if (visited_related_) { // If meet the target consumer, stop searching - visited_block_ = visited_block_ || previous_visited_block; - return; + break; } } + visited_block_ = visited_block_ || previous_visited_block; } void VisitStmt_(const BlockNode* block) final { From c2d0efc3e7026b706366a7bb409314cd9be6bd90 Mon Sep 17 00:00:00 2001 From: Min Chen Date: Fri, 11 Nov 2022 07:06:16 +0000 Subject: [PATCH 3/3] Add test case for CacheLocDetector issue. --- .../test_tir_schedule_cache_read_write.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index a237a5b75839..3476ca083056 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -59,6 +59,33 @@ def elementwise_shape_int64(a: T.handle, c: T.handle) -> None: C[vi, vj] = B[vi, vj] + 1.0 +@T.prim_func +def func_nested_seq(b: T.handle, c: T.handle) -> None: + A = T.alloc_buffer((128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j in T.grid(128, 128): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = 2.0 + for i, j in T.grid(8, 8): + for x, y in T.grid(16, 16): + with T.block("B0"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = 1.0 + for x, y in T.grid(16, 16): + with T.block("B1"): + vi = T.axis.S(128, i * 16 + x) + vj = T.axis.S(128, j * 16 + y) + B[vi, vj] = A[vi, vj] + B[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + + @T.prim_func def access_under_scope(b: T.handle, c: T.handle) -> None: A = T.alloc_buffer((128, 128)) @@ -250,6 +277,47 @@ def inplace_call(data_io: T.Buffer[(64), "int32"]): T.evaluate(T.call_extern("call_impl", data_io.data, dtype="")) +@T.prim_func +def cache_read_nested_seq_target( + B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"] +) -> None: + A = T.alloc_buffer([128, 128], dtype="float32") + A_global = T.alloc_buffer([128, 128], dtype="float32") + for i, j in T.grid(128, 128): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads() + T.writes(A[vi, vj]) + A[vi, vj] = T.float32(2) + for i, j in T.grid(8, 8): + for x, y in T.grid(16, 16): + with T.block("B0"): + vi = T.axis.spatial(128, i * 16 + x) + vj = T.axis.spatial(128, j * 16 + y) + T.reads() + T.writes(B[vi, vj]) + B[vi, vj] = T.float32(1) + for x, y in T.grid(16, 16): + with T.block("B1"): + vi = T.axis.spatial(128, i * 16 + x) + vj = T.axis.spatial(128, j * 16 + y) + T.reads(A[vi, vj], B[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] + B[vi, vj] + for ax0, ax1 in T.grid(128, 128): + with T.block("A_global"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v0, v1]) + T.writes(A_global[v0, v1]) + A_global[v0, v1] = A[v0, v1] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A_global[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = A_global[vi, vj] * T.float32(2) + + ########## Expected function after cache_read ########## @@ -989,6 +1057,14 @@ def test_cache_inplace(): verify_trace_roundtrip(sch=sch, mod=inplace_call, debug_mask=debug_mask) +def test_cache_read_nested_seq(use_block_name): + sch = tir.Schedule(func_nested_seq, debug_mask="all") + block_c = "C" if use_block_name else sch.get_block("C") + sch.cache_read(block_c, 0, "global", consumer_blocks=[block_c]) + tvm.ir.assert_structural_equal(cache_read_nested_seq_target, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=func_nested_seq) + + ########## Testcases for cache_write ##########