diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 0d57f7928f47..2bfc8420b025 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -533,15 +533,29 @@ class BufferStrideLegalize : public StmtExprMutator { // be simplified in the future by having AllocateNode hold a buffer, // rather than a buffer_var. Stmt VisitStmt_(const AllocateNode* op) final { - allocate_node_var_.insert(op->buffer_var.get()); + buffer_var_defines_.insert(op->buffer_var.get()); return StmtExprMutator::VisitStmt_(op); } Stmt VisitStmt_(const AllocateConstNode* op) final { - allocate_node_var_.insert(op->buffer_var.get()); + buffer_var_defines_.insert(op->buffer_var.get()); return StmtExprMutator::VisitStmt_(op); } + Stmt VisitStmt_(const LetStmtNode* op) final { + if (op->var.dtype().is_handle()) { + buffer_var_defines_.insert(op->var.get()); + } + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const LetNode* op) final { + if (op->var.dtype().is_handle()) { + buffer_var_defines_.insert(op->var.get()); + } + return StmtExprMutator::VisitExpr_(op); + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { Buffer key = op->buffer; Buffer with_strides = WithStrides(op->buffer); @@ -575,7 +589,7 @@ class BufferStrideLegalize : public StmtExprMutator { template Node VisitBufferAccess(Node node) { auto alloc_key = node->buffer->data.get(); - if (!buf_map_.count(node->buffer) && allocate_node_var_.count(alloc_key)) { + if (!buf_map_.count(node->buffer) && buffer_var_defines_.count(alloc_key)) { BufferEntry entry; entry.remap_to = WithStrides(node->buffer); entry.in_scope = true; @@ -615,7 +629,7 @@ class BufferStrideLegalize : public StmtExprMutator { - // Set of vars that have occurred in an AllocateNode, but haven't - // yet occurred in a BufferLoad/BufferStore. + // Set of vars defined by an AllocateNode, AllocateConstNode, or handle-typed + // LetStmt/Let, but that haven't yet occurred in a BufferLoad/BufferStore. 
- std::unordered_set allocate_node_var_; + std::unordered_set buffer_var_defines_; IRVisitorWithAnalyzer* bound_analyzer_; }; @@ -918,15 +932,29 @@ class BufferBindUnwrapper : public StmtExprMutator { // be simplified in the future by having AllocateNode hold a buffer, // rather than a buffer_var. Stmt VisitStmt_(const AllocateNode* op) final { - allocate_node_var_.insert(op->buffer_var.get()); + buffer_var_defines_.insert(op->buffer_var.get()); return StmtExprMutator::VisitStmt_(op); } Stmt VisitStmt_(const AllocateConstNode* op) final { - allocate_node_var_.insert(op->buffer_var.get()); + buffer_var_defines_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const LetStmtNode* op) final { + if (op->var.dtype().is_handle()) { + buffer_var_defines_.insert(op->var.get()); + } return StmtExprMutator::VisitStmt_(op); } + PrimExpr VisitExpr_(const LetNode* op) final { + if (op->var.dtype().is_handle()) { + buffer_var_defines_.insert(op->var.get()); + } + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); @@ -1118,7 +1146,7 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& GetBufferEntry(Buffer buffer) { auto alloc_key = buffer->data.get(); - if (!buf_map_.count(buffer.get()) && allocate_node_var_.count(alloc_key)) { + if (!buf_map_.count(buffer.get()) && buffer_var_defines_.count(alloc_key)) { BufferEntry entry; entry.buffer = buffer; buf_map_[buffer.get()] = std::move(entry); @@ -1138,7 +1166,7 @@ class BufferBindUnwrapper : public StmtExprMutator { std::unordered_map buf_map_; - // Set of vars that have occurred in an AllocateNode, but haven't - // yet occurred in a BufferLoad/BufferStore. + // Set of vars defined by an AllocateNode, AllocateConstNode, or handle-typed + // LetStmt/Let, but that haven't yet occurred in a BufferLoad/BufferStore. - std::unordered_set allocate_node_var_; + std::unordered_set buffer_var_defines_; // Analyzer for the variable bounds, used to simplify the bounds populator. We really need the // analyzer from it. 
However IRVisitorWithAnalyzer* bound_analyzer_; @@ -1376,15 +1404,29 @@ class StorageFlattener : public StmtExprMutator { // be simplified in the future by having AllocateNode hold a buffer, // rather than a buffer_var. Stmt VisitStmt_(const AllocateNode* op) final { - allocate_node_var_.insert(op->buffer_var.get()); + buffer_var_defines_.insert(op->buffer_var.get()); return StmtExprMutator::VisitStmt_(op); } Stmt VisitStmt_(const AllocateConstNode* op) final { - allocate_node_var_.insert(op->buffer_var.get()); + buffer_var_defines_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const LetStmtNode* op) final { + if (op->var.dtype().is_handle()) { + buffer_var_defines_.insert(op->var.get()); + } return StmtExprMutator::VisitStmt_(op); } + PrimExpr VisitExpr_(const LetNode* op) final { + if (op->var.dtype().is_handle()) { + buffer_var_defines_.insert(op->var.get()); + } + return StmtExprMutator::VisitExpr_(op); + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { const auto& key = op->buffer; @@ -1598,7 +1640,7 @@ class StorageFlattener : public StmtExprMutator { const BufferEntry& GetBufferEntry(Buffer buffer) { auto alloc_key = buffer->data.get(); - if (!buf_map_.count(buffer) && allocate_node_var_.count(alloc_key)) { + if (!buf_map_.count(buffer) && buffer_var_defines_.count(alloc_key)) { BufferEntry entry; entry.buffer = buffer; entry.flattened_buffer = buffer.GetFlattenedBuffer(); @@ -1622,7 +1664,7 @@ class StorageFlattener : public StmtExprMutator { std::unordered_map var_remap_; - // Set of vars that have occurred in an AllocateNode, but haven't - // yet occurred in a BufferLoad/BufferStore. + // Set of vars defined by an AllocateNode, AllocateConstNode, or handle-typed + // LetStmt/Let, but that haven't yet occurred in a BufferLoad/BufferStore. - std::unordered_set allocate_node_var_; + std::unordered_set buffer_var_defines_; // Buffer map std::unordered_map buf_map_; // The extern buffer map, updated to include flattened buffers. 
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 17afe7881184..44db6181758f 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -130,6 +130,25 @@ def count_sync(op): assert count[0] == 4 +def test_flatten_let_buffer(): + @tvm.script.ir_module + class module: + @T.prim_func + def main(): + T.func_attr({"from_legacy_te_schedule": True}) + + # If a pointer is defined using a LetStmt, + A_data: T.Ptr[T.int32] = T.call_extern("dummy_extern_function", dtype="handle") + + # and a buffer is backed by that pointer, + A: T.Buffer = T.buffer_decl([1], dtype="float32", data=A_data) + T.evaluate(A[0]) + + # then StorageFlatten should run without raising (regression test: + # this call used to result in an exception being thrown). + tvm.tir.transform.StorageFlatten(64)(module) + + @T.prim_func def tir_func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [2, 2]) @@ -146,8 +165,4 @@ def test_flatten_tir(): if __name__ == "__main__": - test_flatten2() - test_flatten_storage_align() - test_flatten_double_buffer() - test_flatten_prefetch() - test_flatten_tir() + sys.exit(pytest.main(sys.argv))