From ba960896dac1ce13601ac95b5d2bf9945b932275 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 12 Dec 2022 23:49:44 -0800 Subject: [PATCH 1/2] Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer --- .../plan_update_buffer_allocation_location.cc | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 4c63d3393fd8..27a8f3582c88 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -61,24 +61,33 @@ class BufferAllocateOrderCollector : public StmtExprVisitor { } private: + bool find(const Buffer& buf) { + return std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), buf) != + buffer_alloc_recorder_.end(); + } + void VisitStmt_(const BlockNode* op) final { for (const Buffer& buffer : op->alloc_buffers) { buffer_alloc_recorder_.push_back(buffer); } + for (const auto& region : op->match_buffers) { + if (!find(region->source->buffer)) { + buffer_alloc_recorder_.push_back(region->source->buffer); + } + } + StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const BufferLoadNode* op) final { - if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) == - buffer_alloc_recorder_.end()) { + if (!find(op->buffer)) { buffer_alloc_recorder_.push_back(op->buffer); } StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode* op) final { - if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) == - buffer_alloc_recorder_.end()) { + if (!find(op->buffer)) { buffer_alloc_recorder_.push_back(op->buffer); } StmtExprVisitor::VisitStmt_(op); From bbd906cce42deb7d4afba828f2c36be61a550d0c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 13 Dec 2022 17:11:52 +0900 Subject: [PATCH 2/2] add comment --- src/tir/transforms/plan_update_buffer_allocation_location.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 27a8f3582c88..11d8330ec8fe 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -70,6 +70,8 @@ class BufferAllocateOrderCollector : public StmtExprVisitor { for (const Buffer& buffer : op->alloc_buffers) { buffer_alloc_recorder_.push_back(buffer); } + // Also visit match_buffers to collect constant buffers associated with AllocateConst nodes. + // These buffers only appear in read and match_buffer regions. for (const auto& region : op->match_buffers) { if (!find(region->source->buffer)) { buffer_alloc_recorder_.push_back(region->source->buffer);