Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 54 additions & 12 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -533,15 +533,29 @@ class BufferStrideLegalize : public StmtExprMutator {
// be simplified in the future by having AllocateNode hold a buffer,
// rather than a buffer_var.
Stmt VisitStmt_(const AllocateNode* op) final {
// Record the allocation's buffer_var, then defer to the base mutator.
// NOTE(review): both allocate_node_var_ and buffer_var_defines_ are
// updated here; this looks like the old and new name of a single set
// listed together -- confirm only one insert is meant to remain.
allocate_node_var_.insert(op->buffer_var.get());
buffer_var_defines_.insert(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const AllocateConstNode* op) final {
// Same bookkeeping as for AllocateNode: remember the buffer_var of the
// constant allocation before the base mutator visits the node.
// NOTE(review): two sets are written (allocate_node_var_ and
// buffer_var_defines_) -- presumably a rename shown twice; confirm.
allocate_node_var_.insert(op->buffer_var.get());
buffer_var_defines_.insert(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const LetStmtNode* op) final {
  // A let-bound variable of handle type can act as a buffer's data
  // pointer, so note it in buffer_var_defines_ before visiting the node.
  const bool binds_handle = op->var.dtype().is_handle();
  if (binds_handle) {
    buffer_var_defines_.insert(op->var.get());
  }
  return StmtExprMutator::VisitStmt_(op);
}

PrimExpr VisitExpr_(const LetNode* op) final {
  // Expression-level counterpart of the LetStmt visitor: handle-typed
  // bindings are added to buffer_var_defines_, others are left alone.
  const bool binds_handle = op->var.dtype().is_handle();
  if (binds_handle) {
    buffer_var_defines_.insert(op->var.get());
  }
  return StmtExprMutator::VisitExpr_(op);
}

Stmt VisitStmt_(const BufferRealizeNode* op) final {
Buffer key = op->buffer;
Buffer with_strides = WithStrides(op->buffer);
Expand Down Expand Up @@ -575,7 +589,7 @@ class BufferStrideLegalize : public StmtExprMutator {
template <typename Node>
Node VisitBufferAccess(Node node) {
auto alloc_key = node->buffer->data.get();
if (!buf_map_.count(node->buffer) && allocate_node_var_.count(alloc_key)) {
if (!buf_map_.count(node->buffer) && buffer_var_defines_.count(alloc_key)) {
BufferEntry entry;
entry.remap_to = WithStrides(node->buffer);
entry.in_scope = true;
Expand Down Expand Up @@ -615,7 +629,7 @@ class BufferStrideLegalize : public StmtExprMutator {

// Set of vars that have been defined by an AllocateNode, an
// AllocateConstNode, or a handle-typed LetStmt/Let, but haven't
// yet occurred in a BufferLoad/BufferStore.
std::unordered_set<const VarNode*> allocate_node_var_;
std::unordered_set<const VarNode*> buffer_var_defines_;

IRVisitorWithAnalyzer* bound_analyzer_;
};
Expand Down Expand Up @@ -918,15 +932,29 @@ class BufferBindUnwrapper : public StmtExprMutator {
// be simplified in the future by having AllocateNode hold a buffer,
// rather than a buffer_var.
Stmt VisitStmt_(const AllocateNode* op) final {
// Record the allocation's buffer_var, then defer to the base mutator.
// NOTE(review): both allocate_node_var_ and buffer_var_defines_ are
// updated here; this looks like the old and new name of a single set
// listed together -- confirm only one insert is meant to remain.
allocate_node_var_.insert(op->buffer_var.get());
buffer_var_defines_.insert(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const AllocateConstNode* op) final {
// Same bookkeeping as for AllocateNode: remember the buffer_var of the
// constant allocation before the base mutator visits the node.
// NOTE(review): two sets are written (allocate_node_var_ and
// buffer_var_defines_) -- presumably a rename shown twice; confirm.
allocate_node_var_.insert(op->buffer_var.get());
buffer_var_defines_.insert(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const LetStmtNode* op) final {
  // Only handle-typed bindings are tracked; GetBufferEntry looks a
  // buffer's data var up in buffer_var_defines_.
  if (!op->var.dtype().is_handle()) {
    return StmtExprMutator::VisitStmt_(op);
  }
  buffer_var_defines_.insert(op->var.get());
  return StmtExprMutator::VisitStmt_(op);
}

PrimExpr VisitExpr_(const LetNode* op) final {
  // Mirror of the LetStmt case for let-expressions: a non-handle binding
  // goes straight to the base mutator, a handle binding is recorded first.
  if (!op->var.dtype().is_handle()) {
    return StmtExprMutator::VisitExpr_(op);
  }
  buffer_var_defines_.insert(op->var.get());
  return StmtExprMutator::VisitExpr_(op);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<BufferLoadNode>();
Expand Down Expand Up @@ -1118,7 +1146,7 @@ class BufferBindUnwrapper : public StmtExprMutator {

const BufferEntry& GetBufferEntry(Buffer buffer) {
auto alloc_key = buffer->data.get();
if (!buf_map_.count(buffer.get()) && allocate_node_var_.count(alloc_key)) {
if (!buf_map_.count(buffer.get()) && buffer_var_defines_.count(alloc_key)) {
BufferEntry entry;
entry.buffer = buffer;
buf_map_[buffer.get()] = std::move(entry);
Expand All @@ -1138,7 +1166,7 @@ class BufferBindUnwrapper : public StmtExprMutator {
std::unordered_map<const BufferNode*, BufferEntry> buf_map_;
// Set of vars that have been defined by an AllocateNode, an
// AllocateConstNode, or a handle-typed LetStmt/Let, but haven't
// yet occurred in a BufferLoad/BufferStore.
std::unordered_set<const VarNode*> allocate_node_var_;
std::unordered_set<const VarNode*> buffer_var_defines_;
// Analyzer for the variable bounds, used to simplify the bounds populator. Only the analyzer
// held by this visitor is really needed here.
IRVisitorWithAnalyzer* bound_analyzer_;
Expand Down Expand Up @@ -1376,15 +1404,29 @@ class StorageFlattener : public StmtExprMutator {
// be simplified in the future by having AllocateNode hold a buffer,
// rather than a buffer_var.
Stmt VisitStmt_(const AllocateNode* op) final {
// Record the allocation's buffer_var, then defer to the base mutator.
// NOTE(review): both allocate_node_var_ and buffer_var_defines_ are
// updated here; this looks like the old and new name of a single set
// listed together -- confirm only one insert is meant to remain.
allocate_node_var_.insert(op->buffer_var.get());
buffer_var_defines_.insert(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const AllocateConstNode* op) final {
// Same bookkeeping as for AllocateNode: remember the buffer_var of the
// constant allocation before the base mutator visits the node.
// NOTE(review): two sets are written (allocate_node_var_ and
// buffer_var_defines_) -- presumably a rename shown twice; confirm.
allocate_node_var_.insert(op->buffer_var.get());
buffer_var_defines_.insert(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const LetStmtNode* op) final {
  // Track pointer-like (handle-typed) let bindings so that a buffer whose
  // data var is bound this way is found in buffer_var_defines_.
  const auto& bound_var = op->var;
  if (bound_var.dtype().is_handle()) {
    buffer_var_defines_.insert(bound_var.get());
  }
  return StmtExprMutator::VisitStmt_(op);
}

PrimExpr VisitExpr_(const LetNode* op) final {
  // Let-expression variant: remember a handle-typed bound variable, then
  // let the base mutator process the expression as usual.
  const auto& bound_var = op->var;
  if (bound_var.dtype().is_handle()) {
    buffer_var_defines_.insert(bound_var.get());
  }
  return StmtExprMutator::VisitExpr_(op);
}

Stmt VisitStmt_(const BufferRealizeNode* op) final {
const auto& key = op->buffer;

Expand Down Expand Up @@ -1598,7 +1640,7 @@ class StorageFlattener : public StmtExprMutator {

const BufferEntry& GetBufferEntry(Buffer buffer) {
auto alloc_key = buffer->data.get();
if (!buf_map_.count(buffer) && allocate_node_var_.count(alloc_key)) {
if (!buf_map_.count(buffer) && buffer_var_defines_.count(alloc_key)) {
BufferEntry entry;
entry.buffer = buffer;
entry.flattened_buffer = buffer.GetFlattenedBuffer();
Expand All @@ -1622,7 +1664,7 @@ class StorageFlattener : public StmtExprMutator {
std::unordered_map<const VarNode*, PrimExpr> var_remap_;
// Set of vars that have been defined by an AllocateNode, an
// AllocateConstNode, or a handle-typed LetStmt/Let, but haven't
// yet occurred in a BufferLoad/BufferStore.
std::unordered_set<const VarNode*> allocate_node_var_;
std::unordered_set<const VarNode*> buffer_var_defines_;
// Buffer map
std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
// The extern buffer map, updated to include flattened buffers.
Expand Down
25 changes: 20 additions & 5 deletions tests/python/unittest/test_tir_transform_storage_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,25 @@ def count_sync(op):
assert count[0] == 4


def test_flatten_let_buffer():
# Runs StorageFlatten on a module whose only buffer is backed by a
# pointer bound through a LetStmt; the test passes if the pass
# completes without raising.
@tvm.script.ir_module
class module:
@T.prim_func
def main():
T.func_attr({"from_legacy_te_schedule": True})

# If a pointer is defined using a LetStmt,
# NOTE(review): the pointer is annotated T.Ptr[T.int32] while the
# buffer below uses dtype="float32" -- confirm this is intentional.
A_data: T.Ptr[T.int32] = T.call_extern("dummy_extern_function", dtype="handle")

# and a buffer is backed by that pointer,
A: T.Buffer = T.buffer_decl([1], dtype="float32", data=A_data)
T.evaluate(A[0])

# then the call to StorageFlatten would result in an exception
# being thrown (presumably describing the behaviour this test guards
# against -- there is no pytest.raises, so the call must now succeed).
tvm.tir.transform.StorageFlatten(64)(module)


@T.prim_func
def tir_func(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [2, 2])
Expand All @@ -146,8 +165,4 @@ def test_flatten_tir():


if __name__ == "__main__":
# NOTE(review): the explicit test_* calls and the pytest.main() entry
# point are both listed here; this looks like the old and new form of
# the script entry point shown together -- confirm which one is live.
test_flatten2()
test_flatten_storage_align()
test_flatten_double_buffer()
test_flatten_prefetch()
test_flatten_tir()
# Hand the command-line arguments to pytest and exit with its return code.
sys.exit(pytest.main(sys.argv))