src/tir/analysis/buffer_access_lca_detector.cc (96 additions & 10 deletions)
@@ -99,30 +99,113 @@ class LCADetector : public StmtExprVisitor {
}

ancestor_scopes_.push_back(current_scope);
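// Keep the mapping from this loop's var to its scope alive only while the
// loop is open, so opaque block iter bindings visited inside the loop body
// can look up the scope of each loop var they reference.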
loop_scope_map_.insert({op->loop_var.get(), current_scope});
StmtExprVisitor::VisitStmt_(op);
ancestor_scopes_.pop_back();
loop_scope_map_.erase(op->loop_var.get());
}

- void VisitStmt_(const BlockNode* op) final {
+ void VisitStmt_(const BlockRealizeNode* op) final {
+ const BlockNode* block = op->block.get();
int n = ancestor_scopes_.size();
- for (const Buffer& buf : op->alloc_buffers) {
+ for (const Buffer& buf : block->alloc_buffers) {
buffer_var_map_.emplace(buf->data.get(), buf.get());
}

const ScopeInfo* parent_scope = ancestor_scopes_.back();
- auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
+ auto* current_scope = arena_.make<ScopeInfo>(parent_scope, block, n);

ancestor_scopes_.push_back(current_scope);

// For each buffer accessed by the block, update the buffer's LCA to the
// lowest enclosing stmt position, which must dominate every loop that the
// opaque block iter vars appearing in the buffer indices are bound to.
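// For example, with `vj = T.axis.opaque(8, j)` and a read of C[vi, vj - 1, vk],
// the block consumes values of C produced in earlier iterations of loop j, so
// C's LCA must be lifted above loop j (see test_loop_carried_dependency below).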
UpdateDominateScopeOfOpaqueIter(op);

// Update match_buffers
- for (const MatchBufferRegion& match_buffer : op->match_buffers) {
- UpdateBufferLCA(match_buffer->source->buffer.get());
+ for (const MatchBufferRegion& match_buffer : block->match_buffers) {
+ UpdateBufferLCA(match_buffer->source->buffer.get(), ancestor_scopes_.back());
match_buffers_.insert(match_buffer->buffer.get());
}

StmtExprVisitor::VisitStmt_(op);
ancestor_scopes_.pop_back();
}

void UpdateDominateScopeOfOpaqueIter(const BlockRealizeNode* block_realize) {
// Map each opaque block iter var to the scope that dominates all of its
// loop-carried dependencies.
std::unordered_map<const VarNode*, const ScopeInfo*> itervar_to_dom_scope;
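// E.g. for a binding `vj <- j` under directly nested loops i > j, the
// dominating scope of `vj` is loop i's scope, i.e. the parent of loop j's scope.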

// Helper to populate `itervar_to_dom_scope`: the resulting scope for each
// opaque block iter var must sit above every loop scope its binding refers to.
auto do_collect_itervar_scope = [this, &itervar_to_dom_scope](const IterVar& itervar,
const PrimExpr& binding) {
PostOrderVisit(binding, [this, &itervar_to_dom_scope, &itervar](const ObjectRef& obj) {
if (const VarNode* loop_var = obj.as<VarNode>()) {
auto it = loop_scope_map_.find(loop_var);
if (it == loop_scope_map_.end()) {
return;
}
const ScopeInfo* scope = it->second->parent_scope_info;
// Keep the highest (shallowest) loop scope that the iter var's binding
// relates to; the loop's parent is taken so the allocation can be placed
// above the loop itself.
auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get());
if (dom_scope_it == itervar_to_dom_scope.end()) {
itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope});
} else if (scope->depth < dom_scope_it->second->depth) {
dom_scope_it->second = scope;
}
}
});
};

// Helper to update the LCA scope of a buffer whose accesses have loop-carried
// dependencies: the resulting scope must sit above every loop scope that the
// accessed opaque block iter vars relate to, as recorded in `itervar_to_dom_scope`.
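// E.g. if a region accesses C[vi, vj - 1, vk] and `vj`'s dominating scope is
// loop i's scope, then C's LCA is lifted up to loop i's scope.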
auto do_update = [this, &itervar_to_dom_scope](const BufferRegion& region) {
const Buffer& buffer = region->buffer;
const ScopeInfo* scope = ancestor_scopes_.back();

auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) {
if (const VarNode* iter_var = obj.as<VarNode>()) {
auto dom_scope_it = itervar_to_dom_scope.find(iter_var);
if (dom_scope_it == itervar_to_dom_scope.end()) {
return;
}
// Lift to the highest loop scope on which the accessed buffer index has a
// loop-carried dependency (via the opaque iter var's binding).
if (dom_scope_it->second->depth < scope->depth) {
scope = dom_scope_it->second;
}
}
};

// Visit each region's min and max bounds to find the lowest legal LCA scope.
for (const Range& range : region->region) {
PostOrderVisit(range->min, handle_itervar);
PostOrderVisit(range->min + range->extent - 1, handle_itervar);
}
UpdateBufferLCA(buffer.get(), scope);
};

// Collect dominating scopes for opaque iter vars, then update the LCAs of accessed buffers.
const Block& block = block_realize->block;
for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
const IterVar& iter_var = block->iter_vars[i];
if (iter_var->iter_type != IterVarType::kDataPar &&
iter_var->iter_type != IterVarType::kCommReduce) {
do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
}
}
if (!itervar_to_dom_scope.empty()) {
for (const auto& read : block->reads) {
do_update(read);
}
for (const auto& write : block->writes) {
do_update(write);
}
}
}

void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
const auto* iter = op->node.as<IterVarNode>();
@@ -136,17 +219,18 @@
}

void VisitExpr_(const BufferLoadNode* op) final {
- UpdateBufferLCA(op->buffer.get());
+ UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode* op) final {
- UpdateBufferLCA(op->buffer.get());
+ UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
StmtExprVisitor::VisitStmt_(op);
}

void VisitStmt_(const BufferRealizeNode* op) final {
buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get());
UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
StmtExprVisitor::VisitStmt_(op);
}

@@ -165,16 +249,16 @@
void VisitBufferVar(const VarNode* op) {
auto it = buffer_var_map_.find(op);
if (it != buffer_var_map_.end()) {
- UpdateBufferLCA(it->second);
+ UpdateBufferLCA(it->second, ancestor_scopes_.back());
}
}

- void UpdateBufferLCA(const BufferNode* buffer) {
+ void UpdateBufferLCA(const BufferNode* buffer, const ScopeInfo* scope) {
buffer_var_map_.emplace(buffer->data.get(), buffer);
if (match_buffers_.find(buffer) == match_buffers_.end()) {
// Ignore buffers created by block match_buffer
const ScopeInfo*& lca = buffer_lca_[buffer];
- lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
+ lca = LowestCommonAncestor(lca, scope);
}
}

@@ -229,6 +313,8 @@ class LCADetector : public StmtExprVisitor {
std::unordered_set<const BufferNode*> match_buffers_ = {};
/*! \brief The ForNodes/BlockNodes that contain an immediate `blockIdx` launch. */
std::vector<const ScopeInfo*> blockidx_scopes_ = {};
/*! \brief The map from loop var to the corresponding scope. */
std::unordered_map<const VarNode*, const ScopeInfo*> loop_scope_map_ = {};
/*! \brief Internal arena. */
support::Arena arena_;
};
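For reference, the analysis implemented by this class is exposed to Python as
`tvm.tir.analysis.detect_buffer_access_lca`. A minimal sketch of inspecting its
result on a toy function (illustrative only, not part of this patch):

import tvm
from tvm.script import tir as T

@T.prim_func
def toy(A: T.Buffer[(8,), "int32"], B: T.Buffer[(8,), "int32"]):
    C = T.alloc_buffer([8], dtype="int32")
    for i in T.serial(8):
        with T.block("b0"):
            vi = T.axis.spatial(8, i)
            C[vi] = A[vi] + 1
        with T.block("b1"):
            vi = T.axis.spatial(8, i)
            B[vi] = C[vi] * 2

# Maps each buffer to the lowest For/Block node enclosing all of its accesses;
# here C's LCA is expected to be the loop over i.
lca = tvm.tir.analysis.detect_buffer_access_lca(toy)
for buf, stmt in lca.items():
    print(buf.name, type(stmt).__name__)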
tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te
from tvm.script import tir as T

@@ -242,9 +243,107 @@ def test_lower_te():
) # PlanAndUpdateBufferAllocationLocation should do nothing on TE
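The `_check` helper used by the tests below is defined earlier in this file,
outside the visible hunks. A minimal sketch of its conventional shape (an
assumption based on the standard TVM transform-test pattern, not part of this
diff):

def _check(original, transformed):
    # Wrap the PrimFunc in an IRModule, run the pass under test,
    # then compare structurally against the expected function.
    mod = tvm.IRModule.from_expr(original)
    mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
    tvm.ir.assert_structural_equal(mod["main"], transformed)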


def test_loop_carried_dependency():
    """The buffer allocation should be placed above the opaque iter var's loop
    scopes so that buffer accesses with loop-carried dependencies are covered."""

    @T.prim_func
    def before(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]):
        C = T.alloc_buffer([8, 8, 8], dtype="int32")
        for i in T.serial(8):
            for j in T.serial(8):
                for k in T.serial(8):
                    with T.block("b0"):
                        vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                        C[vi, vj, vk] = A[vi, vj, vk] + 1
                for k in T.serial(8):
                    with T.block("b1"):
                        vi, vk = T.axis.remap("SS", [i, k])
                        vj = T.axis.opaque(8, j)
                        B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
                            0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
                        )

    @T.prim_func
    def after(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]) -> None:
        for i in T.serial(8):
            with T.block():
                T.reads(A[i, 0:8, 0:8])
                T.writes(B[i, 0:8, 0:8])
                C = T.alloc_buffer([8, 8, 8], dtype="int32")
                for j in T.serial(8):
                    for k in T.serial(8):
                        with T.block("b0"):
                            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                            C[vi, vj, vk] = A[vi, vj, vk] + 1
                    for k in T.serial(8):
                        with T.block("b1"):
                            vi, vk = T.axis.remap("SS", [i, k])
                            vj = T.axis.opaque(8, j)
                            B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
                                0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
                            )

    _check(before, after)


def test_1D_cascade_op_rolling_buffer():
    """The intermediate buffer must be allocated above the rolling buffer's
    rolling loop, which is marked as opaque in the consumer block's iter var
    mappings."""

    @T.prim_func
    def before(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]):
        B = T.alloc_buffer((4, 6), "int32")
        for c in T.serial(4):
            for i in T.serial(0, 2):
                for j in T.serial(0, 6):
                    for k in T.serial(3):
                        with T.block("P1"):
                            T.where(i < 1 or j >= 2)
                            cc, vi, vj, vk = T.axis.remap("SSSR", [c, i, j, k])
                            if vk == 0:
                                B[cc, T.floormod(vi * 4 + vj, 6)] = 0
                            B[cc, T.floormod(vi * 4 + vj, 6)] = (
                                B[cc, T.floormod(vi * 4 + vj, 6)] + A[cc, vi * 4 + vj + vk]
                            )
                for j in T.serial(0, 4):
                    for k in T.serial(3):
                        with T.block("P2"):
                            vi = T.axis.opaque(2, i)
                            cc, vj, vk = T.axis.remap("SSR", [c, j, k])
                            if vk == 0:
                                C[cc, vi * 4 + vj] = 0
                            C[cc, vi * 4 + vj] = (
                                C[cc, vi * 4 + vj] + B[cc, T.floormod(vi * 4 + vj + vk, 6)]
                            )

    @T.prim_func
    def after(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]):
        for c in T.serial(4):
            with T.block():
                T.reads(A[c, 0:12], C[c, 0:8])
                T.writes(C[c, 0:8])
                B = T.alloc_buffer([4, 6], dtype="int32")
                for i in T.serial(2):
                    for j, k in T.grid(6, 3):
                        with T.block("P1"):
                            T.where(i < 1 or j >= 2)
                            cc, vi, vj, vk = T.axis.remap("SSSR", [c, i, j, k])
                            if vk == 0:
                                B[cc, (vi * 4 + vj) % 6] = 0
                            B[cc, (vi * 4 + vj) % 6] = (
                                B[cc, (vi * 4 + vj) % 6] + A[cc, vi * 4 + vj + vk]
                            )
                    for j, k in T.grid(4, 3):
                        with T.block("P2"):
                            vi = T.axis.opaque(2, i)
                            cc, vj, vk = T.axis.remap("SSR", [c, j, k])
                            if vk == 0:
                                C[cc, vi * 4 + vj] = 0
                            C[cc, vi * 4 + vj] = C[cc, vi * 4 + vj] + B[cc, (vi * 4 + vj + vk) % 6]

    _check(before, after)

if __name__ == "__main__":
-    test_elementwise()
-    test_locate_buffer_allocation()
-    test_match_buffer_allocation()
-    test_opaque_access()
-    test_lower_te()
+    tvm.testing.main()