From d31a773f71c910e4cc39e48a6383f15bc6feb0d8 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Tue, 13 Sep 2022 01:05:36 +0800 Subject: [PATCH 1/3] Fix plan buffer allocation location for loop carried dependencies --- .../analysis/buffer_access_lca_detector.cc | 102 ++++++++++++++-- ..._plan_update_buffer_allocation_location.py | 109 +++++++++++++++++- 2 files changed, 196 insertions(+), 15 deletions(-) diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 7197e1ba83c5..26b290f3142a 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -99,23 +99,32 @@ class LCADetector : public StmtExprVisitor { } ancestor_scopes_.push_back(current_scope); + loop_scope_map_.insert({op->loop_var.get(), current_scope}); StmtExprVisitor::VisitStmt_(op); ancestor_scopes_.pop_back(); + loop_scope_map_.erase(op->loop_var.get()); } - void VisitStmt_(const BlockNode* op) final { + void VisitStmt_(const BlockRealizeNode* op) final { + const BlockNode* block = op->block.get(); int n = ancestor_scopes_.size(); - for (const Buffer& buf : op->alloc_buffers) { + for (const Buffer& buf : block->alloc_buffers) { buffer_var_map_.emplace(buf->data.get(), buf.get()); } const ScopeInfo* parent_scope = ancestor_scopes_.back(); - auto* current_scope = arena_.make(parent_scope, op, n); + auto* current_scope = arena_.make(parent_scope, block, n); ancestor_scopes_.push_back(current_scope); + + // For each buffer the block has accessed, update the buffer's lca + // to the lowest inclusive stmt scope which dominate all loop carried + // dependencies related to the accessed opaque block iter vars. + UpdateDominateScopeOfOpaqueIter(op); + // Update match_buffers - for (const MatchBufferRegion& match_buffer : op->match_buffers) { - UpdateBufferLCA(match_buffer->source->buffer.get()); + for (const MatchBufferRegion& match_buffer : block->match_buffers) { + UpdateBufferLCA(match_buffer->source->buffer.get(), ancestor_scopes_.back()); match_buffers_.insert(match_buffer->buffer.get()); } @@ -123,6 +132,76 @@ class LCADetector : public StmtExprVisitor { ancestor_scopes_.pop_back(); } + void UpdateDominateScopeOfOpaqueIter(const BlockRealizeNode* block_realize) { + // map opaque iter var to the scope which dominate all loop carried dependencies. + std::unordered_map itervar_to_dom_scope; + + // function to collect `itervar_to_dom_scope`, the result scope for each block + // iter var should be above all loop scopes the opaque iter var binding relates to. + auto do_collect_itervar_scope = [this, &itervar_to_dom_scope](const IterVar& itervar, + const PrimExpr& binding) { + PostOrderVisit(binding, [this, &itervar_to_dom_scope, &itervar](const ObjectRef& obj) { + if (const VarNode* loop_var = obj.as()) { + auto it = loop_scope_map_.find(loop_var); + if (it == loop_scope_map_.end()) { + return; + } + const ScopeInfo* scope = it->second->parent_scope_info; + auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get()); + if (dom_scope_it == itervar_to_dom_scope.end()) { + itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope}); + } else if (scope->depth < dom_scope_it->second->depth) { + dom_scope_it->second = scope; + } + } + }); + }; + + // function to update lca scope of the buffer with loop carried dependent buffer accesses. + // the result scope should be above all loop scopes the accessed opaque block iter vars + // relate to, which is record in `itervar_to_dom_scope`. + auto do_update = [this, &itervar_to_dom_scope](const BufferRegion& region) { + const Buffer& buffer = region->buffer; + const ScopeInfo* scope = ancestor_scopes_.back(); + + auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) { + if (const VarNode* iter_var = obj.as()) { + auto dom_scope_it = itervar_to_dom_scope.find(iter_var); + if (dom_scope_it == itervar_to_dom_scope.end()) { + return; + } + if (dom_scope_it->second->depth < scope->depth) { + scope = dom_scope_it->second; + } + } + }; + + for (const Range& range : region->region) { + PostOrderVisit(range->min, handle_itervar); + PostOrderVisit(range->min + range->extent - 1, handle_itervar); + } + UpdateBufferLCA(buffer.get(), scope); + }; + + // execute collect and update + const Block& block = block_realize->block; + for (size_t i = 0; i < block_realize->iter_values.size(); ++i) { + const IterVar& iter_var = block->iter_vars[i]; + if (iter_var->iter_type != IterVarType::kDataPar && + iter_var->iter_type != IterVarType::kCommReduce) { + do_collect_itervar_scope(iter_var, block_realize->iter_values[i]); + } + } + if (!itervar_to_dom_scope.empty()) { + for (const auto& read : block->reads) { + do_update(read); + } + for (const auto& write : block->writes) { + do_update(write); + } + } + } + void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { const auto* iter = op->node.as(); @@ -136,17 +215,18 @@ class LCADetector : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) final { - UpdateBufferLCA(op->buffer.get()); + UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back()); StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode* op) final { - UpdateBufferLCA(op->buffer.get()); + UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back()); StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const BufferRealizeNode* op) final { buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get()); + UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back()); StmtExprVisitor::VisitStmt_(op); } @@ -165,16 +245,16 @@ class LCADetector : public StmtExprVisitor { void VisitBufferVar(const VarNode* op) { auto it = buffer_var_map_.find(op); if (it != buffer_var_map_.end()) { - UpdateBufferLCA(it->second); + UpdateBufferLCA(it->second, ancestor_scopes_.back()); } } - void UpdateBufferLCA(const BufferNode* buffer) { + void UpdateBufferLCA(const BufferNode* buffer, const ScopeInfo* scope) { buffer_var_map_.emplace(buffer->data.get(), buffer); if (match_buffers_.find(buffer) == match_buffers_.end()) { // Ingore buffer created by block match_buffer const ScopeInfo*& lca = buffer_lca_[buffer]; - lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); + lca = LowestCommonAncestor(lca, scope); } } @@ -229,6 +309,8 @@ class LCADetector : public StmtExprVisitor { std::unordered_set match_buffers_ = {}; /*! \brief The ForNodes/BlockNodes which contain immediate `blockIdx` launch. */ std::vector blockidx_scopes_ = {}; + /*! \brief The map from loop var to the corresponding scope. */ + std::unordered_map loop_scope_map_ = {}; /*! \brief Internal arena. */ support::Arena arena_; }; diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index c22f5f82ee10..f44b822764c8 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing from tvm import te from tvm.script import tir as T @@ -242,9 +243,107 @@ def test_lower_te(): ) # PlanAndUpdateBufferAllocationLocation should do nothing on TE +def test_loop_carried_dependency(): + """The buffer allocation should be above opaque iter var's loop scopes + such that buffer accesses with loop carried dependencies are covered.""" + + @T.prim_func + def before(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]): + C = T.alloc_buffer([8, 8, 8], dtype="int32") + for i in T.serial(8): + for j in T.serial(8): + for k in T.serial(8): + with T.block("b0"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = A[vi, vj, vk] + 1 + for k in T.serial(8): + with T.block("b1"): + vi, vk = T.axis.remap("SS", [i, k]) + vj = T.axis.opaque(8, j) + B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else( + 0 < vj, C[vi, vj - j, vk], 0, dtype="int32" + ) + + @T.prim_func + def after(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]) -> None: + for i in T.serial(8): + with T.block(): + T.reads(A[i, 0:8, 0:8]) + T.writes(B[i, 0:8, 0:8]) + C = T.alloc_buffer([8, 8, 8], dtype="int32") + for j in T.serial(8): + for k in T.serial(8): + with T.block("b0"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = A[vi, vj, vk] + 1 + for k in T.serial(8): + with T.block("b1"): + vi, vk = T.axis.remap("SS", [i, k]) + vj = T.axis.opaque(8, j) + B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else( + 0 < vj, C[vi, vj - j, vk], 0, dtype="int32" + ) + + _check(before, after) + + +def test_1D_cascade_op_rolling_buffer(): + """The intermediate buffer must be allocated above rolling buffer's rolling loop, + which is marked as opaque in consumer block's iter mappings.""" + + @T.prim_func + def before(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]): + B = T.alloc_buffer((4, 6), "int32") + for c in T.serial(4): + for i in T.serial(0, 2): + for j in T.serial(0, 6): + for k in T.serial(3): + with T.block("P1"): + T.where(i < 1 or j >= 2) + cc, vi, vj, vk = T.axis.remap("SSSR", [c, i, j, k]) + if vk == 0: + B[cc, T.floormod(vi * 4 + vj, 6)] = 0 + B[cc, T.floormod(vi * 4 + vj, 6)] = ( + B[cc, T.floormod(vi * 4 + vj, 6)] + A[cc, vi * 4 + vj + vk] + ) + for j in T.serial(0, 4): + for k in T.serial(3): + with T.block("P2"): + vi = T.axis.opaque(2, i) + cc, vj, vk = T.axis.remap("SSR", [c, j, k]) + if vk == 0: + C[cc, i * 4 + j] = 0 + C[cc, vi * 4 + vj] = ( + C[cc, vi * 4 + vj] + B[cc, T.floormod(vi * 4 + vj + vk, 6)] + ) + + @T.prim_func + def after(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]): + for c in T.serial(4): + with T.block(): + T.reads(A[c, 0:12], C[c, 0:8]) + T.writes(C[c, 0:15]) + B = T.alloc_buffer([4, 6], dtype="int32") + for i in T.serial(2): + for j, k in T.grid(6, 3): + with T.block("P1"): + T.where(i < 1 or j >= 2) + cc, vi, vj, vk = T.axis.remap("SSSR", [c, i, j, k]) + if vk == 0: + B[cc, (vi * 4 + vj) % 6] = 0 + B[cc, (vi * 4 + vj) % 6] = ( + B[cc, (vi * 4 + vj) % 6] + A[cc, vi * 4 + vj + vk] + ) + for j, k in T.grid(4, 3): + with T.block("P2"): + vi = T.axis.opaque(2, i) + cc, vj, vk = T.axis.remap("SSR", [c, j, k]) + if vk == 0: + C[cc, i * 4 + j] = 0 + C[cc, vi * 4 + vj] = C[cc, vi * 4 + vj] + B[cc, (vi * 4 + vj + vk) % 6] + + _check(before, after) + + if __name__ == "__main__": - test_elementwise() - test_locate_buffer_allocation() - test_match_buffer_allocation() - test_opaque_access() - test_lower_te() + tvm.testing.main() From dc143adf15ad00f8600804aa8e925e3d65e0e251 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 14 Sep 2022 23:32:56 +0800 Subject: [PATCH 2/3] fix testcase region annotation issue --- src/tir/analysis/buffer_access_lca_detector.cc | 12 ++++++++---- ...ansform_plan_update_buffer_allocation_location.py | 6 +++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 26b290f3142a..64d10fae2ff1 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -117,9 +117,9 @@ class LCADetector : public StmtExprVisitor { ancestor_scopes_.push_back(current_scope); - // For each buffer the block has accessed, update the buffer's lca - // to the lowest inclusive stmt scope which dominate all loop carried - // dependencies related to the accessed opaque block iter vars. + // For each accessed buffer of the block, update the buffer's lca to + // the lowest inclusive stmt position, which should dominate all loops + // related to the accessed opaque block iter vars in buffer indices. UpdateDominateScopeOfOpaqueIter(op); // Update match_buffers @@ -147,6 +147,7 @@ class LCADetector : public StmtExprVisitor { return; } const ScopeInfo* scope = it->second->parent_scope_info; + // find the highest loop scope the iter var binding has related to. auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get()); if (dom_scope_it == itervar_to_dom_scope.end()) { itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope}); @@ -170,12 +171,15 @@ class LCADetector : public StmtExprVisitor { if (dom_scope_it == itervar_to_dom_scope.end()) { return; } + // find the highest loop scope the accessed buffer index has + // loop carried dependencies to (via opaque iter var binding). if (dom_scope_it->second->depth < scope->depth) { scope = dom_scope_it->second; } } }; + // visit region min and max to find the lowest legal lca scope for (const Range& range : region->region) { PostOrderVisit(range->min, handle_itervar); PostOrderVisit(range->min + range->extent - 1, handle_itervar); @@ -183,7 +187,7 @@ class LCADetector : public StmtExprVisitor { UpdateBufferLCA(buffer.get(), scope); }; - // execute collect and update + // do collect and update const Block& block = block_realize->block; for (size_t i = 0; i < block_realize->iter_values.size(); ++i) { const IterVar& iter_var = block->iter_vars[i]; diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index f44b822764c8..32d6ebdddcee 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -312,7 +312,7 @@ def before(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]): vi = T.axis.opaque(2, i) cc, vj, vk = T.axis.remap("SSR", [c, j, k]) if vk == 0: - C[cc, i * 4 + j] = 0 + C[cc, vi * 4 + vj] = 0 C[cc, vi * 4 + vj] = ( C[cc, vi * 4 + vj] + B[cc, T.floormod(vi * 4 + vj + vk, 6)] ) @@ -322,7 +322,7 @@ def after(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]): for c in T.serial(4): with T.block(): T.reads(A[c, 0:12], C[c, 0:8]) - T.writes(C[c, 0:15]) + T.writes(C[c, 0:8]) B = T.alloc_buffer([4, 6], dtype="int32") for i in T.serial(2): for j, k in T.grid(6, 3): @@ -339,7 +339,7 @@ def after(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]): vi = T.axis.opaque(2, i) cc, vj, vk = T.axis.remap("SSR", [c, j, k]) if vk == 0: - C[cc, i * 4 + j] = 0 + C[cc, vi * 4 + vj] = 0 C[cc, vi * 4 + vj] = C[cc, vi * 4 + vj] + B[cc, (vi * 4 + vj + vk) % 6] _check(before, after) From fc5efe970cd973c358df3e06ab8a314ae3134793 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sun, 25 Sep 2022 21:55:57 +0800 Subject: [PATCH 3/3] fix typo in ut --- ...st_tir_transform_plan_update_buffer_allocation_location.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index 32d6ebdddcee..34d82f86a422 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -261,7 +261,7 @@ def before(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]): vi, vk = T.axis.remap("SS", [i, k]) vj = T.axis.opaque(8, j) B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else( - 0 < vj, C[vi, vj - j, vk], 0, dtype="int32" + 0 < vj, C[vi, vj - 1, vk], 0, dtype="int32" ) @T.prim_func @@ -281,7 +281,7 @@ def after(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]) -> N vi, vk = T.axis.remap("SS", [i, k]) vj = T.axis.opaque(8, j) B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else( - 0 < vj, C[vi, vj - j, vk], 0, dtype="int32" + 0 < vj, C[vi, vj - 1, vk], 0, dtype="int32" ) _check(before, after)