From 3de4074f60e9749cd0647cf820ddb3bb47571965 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 14 Mar 2024 08:41:07 -0500 Subject: [PATCH] [Bugfix][TIR] Avoid overwrite of unmanaged buffer allocations Prior to this commit, the `tir.PlanAndUpdateBufferAllocationLocation` pass would attempt to merge buffer allocations, unless the buffer's backing allocation was found in a `Allocate`, `AllocateConst`, or `PrimFuncNode::params`. Previous PRs (e.g. https://github.com/apache/tvm/pull/10998) collected these locations and marked them as unmanaged. However, this requires exhaustively checking all locations where unmanaged allocations could occur. This PR updates `tir.PlanAndUpdateBufferAllocationLocation` to instead collect the managed buffers, and only perform rewrites of these managed buffers. This only required inspection of `BlockNode`, and no other constructs. The unit test added in this PR is another location where unmanaged buffers may be produced. --- .../plan_update_buffer_allocation_location.cc | 36 ++++++++----------- ..._plan_update_buffer_allocation_location.py | 33 ++++++++++++++++- 2 files changed, 47 insertions(+), 22 deletions(-) diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 8b3a2d370df1..f9ce708c78b7 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -32,21 +32,21 @@ namespace tvm { namespace tir { -class CollectUnmanagedAllocations : public StmtExprVisitor { +class CollectManagedAllocations : public StmtExprVisitor { public: - void VisitStmt_(const AllocateNode* op) final { - unmanaged_allocations.insert(op->buffer_var.get()); - StmtExprVisitor::VisitStmt_(op); - } - - void VisitStmt_(const AllocateConstNode* op) final { - unmanaged_allocations.insert(op->buffer_var.get()); + void VisitStmt_(const BlockNode* op) final { + for (const auto& buf : op->alloc_buffers) { + managed_allocations.insert(buf->data.get()); + } + for (const auto& buf : op->match_buffers) { + managed_allocations.insert(buf->buffer->data.get()); + } StmtExprVisitor::VisitStmt_(op); } /*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved by * BufferAllocationLocator. */ - std::unordered_set unmanaged_allocations; + std::unordered_set managed_allocations; }; /*! \brief Collect the allocate buffer order. */ @@ -108,15 +108,9 @@ class BufferAllocationLocator : public StmtExprMutator { // since the buffer_lca Map is unordered. Array buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func); std::unordered_set arg_buffer_vars; - CollectUnmanagedAllocations collector; + CollectManagedAllocations collector; collector(func->body); - unmanaged_allocations_ = collector.unmanaged_allocations; - - for (const Var& param : func->params) { - if (param->type_annotation.defined() && param->type_annotation.as()) { - unmanaged_allocations_.insert(param.get()); - } - } + managed_allocations_ = collector.managed_allocations; for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; @@ -131,7 +125,7 @@ class BufferAllocationLocator : public StmtExprMutator { if (arg_buffer_vars.count(buffer->data.get())) { continue; } - if (!unmanaged_allocations_.count(buffer->data.get())) { + if (managed_allocations_.count(buffer->data.get())) { alloc_buffers_[stmt].push_back(buffer); } buffer_data_to_buffer_.Set(buffer->data, buffer); @@ -152,7 +146,7 @@ class BufferAllocationLocator : public StmtExprMutator { Array new_block_alloc_bufs; for (const Buffer& buf : it->second) { - if (!unmanaged_allocations_.count(buf->data.get())) { + if (managed_allocations_.count(buf->data.get())) { buffer_data_to_buffer_.erase(buf->data); new_block_alloc_bufs.push_back(buf); } @@ -243,8 +237,8 @@ class BufferAllocationLocator : public StmtExprMutator { std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ Map buffer_data_to_buffer_; - /*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved. */ - std::unordered_set unmanaged_allocations_; + /*! \brief Buffers that are allocated within a BlockNode, and may be moved. */ + std::unordered_set managed_allocations_; }; PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { diff --git a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py index fe724ad0c981..bb76bd235f15 100644 --- a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py @@ -417,7 +417,8 @@ def test_allocate_const_after_tensorize(): def test_buffer_conditional_lowering(): - """ + """Buffers passed as pointer arguments are unmodified + Confirm that the `tir.PlanAndUpdateBufferAllocationLocation` pass leaves (Buffer nodes corresponding to pointer-typed PrimFunc arguments) unchanged, rather than lowering them to `reads`, `writes`, and `alloc_buffer` nodes. @@ -434,5 +435,35 @@ def before(A: T.handle("float32")): _check(before, after) +def test_dltensor_buffer_is_unlowered(): + """Buffers allocated with a LetStmt are unmodified + + Confirm that the `tir.PlanAndUpdateBufferAllocationLocation` pass + leaves (Buffer nodes corresponding to PrimFunc DLTensor arguments) + unchanged, rather than lowering them to `reads`, `writes`, and + `alloc_buffer` nodes. + """ + + @T.prim_func + def before(dlpack_handle: T.handle, axis: T.int64) -> T.int64: + ndim: T.int32 = T.tvm_struct_get(dlpack_handle, 0, 5, "int32") + stride_ptr: T.handle("int64") = T.tvm_struct_get(dlpack_handle, 0, 4, "handle") + if T.isnullptr(stride_ptr): + shape_ptr: T.handle("int64") = T.tvm_struct_get(dlpack_handle, 0, 3, "handle") + shape = T.decl_buffer(ndim, "int64", data=shape_ptr) + product = T.decl_buffer([], "int64") + product[()] = 1 + for dim in range(axis + 1, ndim): + product[()] = product[()] * shape[dim] + return product[()] + else: + strides = T.decl_buffer(ndim, "int64", data=stride_ptr) + stride: T.int64 = strides[axis] + return stride + + after = before + _check(before, after) + + if __name__ == "__main__": tvm.testing.main()