diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 4baff106096c..fddf73da015b 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -64,6 +64,21 @@ class DecomposeReductionBlockReplacer : public StmtMutator { ObjectPtr p_new_block = CopyOnWrite(block); p_new_block->name_hint = p_new_block->name_hint + "_update"; p_new_block->init = NullOpt; + // Add write regions back to read regions in update block. + Array new_reads; + std::unordered_set read_bufs; + for (const BufferRegion& read_access : block->reads) { + read_bufs.insert(read_access->buffer.get()); + } + for (const BufferRegion& write_access : block->writes) { + if (read_bufs.find(write_access->buffer.get()) == read_bufs.end()) { + new_reads.push_back(write_access); + } + } + for (const BufferRegion& read_access : block->reads) { + new_reads.push_back(read_access); + } + p_new_block->reads = new_reads; new_reduction_block_ = Block(p_new_block); return new_reduction_block_; } else { @@ -284,22 +299,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, /*body=*/body); } body = Substitute(body, loop_var_map); - // Step 6. Add write regions back to read regions in update block. - Array new_reads; - std::unordered_set read_bufs; - for (const BufferRegion& read_access : block->reads) { - read_bufs.insert(read_access->buffer.get()); - } - for (const BufferRegion& write_access : block->writes) { - if (read_bufs.find(write_access->buffer.get()) == read_bufs.end()) { - new_reads.push_back(write_access); - } - } - for (const BufferRegion& read_access : block->reads) { - new_reads.push_back(read_access); - } - (const_cast(block))->reads = std::move(new_reads); - // Step 7. Mutate IR + // Step 6. Mutate IR const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref); Block new_scope_root{nullptr}; Block new_reduction_block{nullptr}; diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index 5ad366b2fa02..4be8ebc2c296 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -282,5 +282,17 @@ def test_reduction_decompose_with_different_for_kind(): verify_trace_roundtrip(s, mod=colsum_with_vectorization) +def test_decompose_reduction_ref_hash_check(): + mod = tvm.IRModule.from_expr(matmul) + mod_bak = mod + hash_before = tvm.ir.structural_hash(mod_bak) + s = tir.Schedule(mod["main"], debug_mask="all") + C = s.get_block("update") + i, j, k = s.get_loops(C) + s.decompose_reduction(C, k) + hash_after = tvm.ir.structural_hash(mod_bak) + assert hash_before == hash_after + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))