[TIR] Route every buffer rewrite in PrimFuncSpecializer through MutateBuffer

What this patch does (all of it visible in the hunks below):
  * MutateBuffer first looks the buffer up in buffer_map_ and returns the
    cached result if present; otherwise it rewrites shape / strides /
    elem_offset and records the result (changed or not) in buffer_map_.
  * The BufferStore and BufferLoad visitors and MutateBufferRegion call
    MutateBuffer instead of only consulting buffer_map_, so a buffer whose
    first appearance is at a use site is still rewritten.
  * MutateAllocBuffer is removed; Block alloc_buffers are mapped with
    MutateBuffer, and Specialize() no longer writes buffer_map_ itself.
  * Adds test_specialize_with_pointer_size and switches the test file's
    __main__ entry to tvm.testing.main().

NOTE(review): this patch text was recovered from a copy in which newlines,
indentation and all `<...>` template argument lists had been stripped.
Template arguments (Array<Buffer>, Array<PrimExpr>, Array<Range>,
Downcast<BufferStore>, Downcast<BufferLoad>, make_object<BufferNode>,
make_object<BufferLoadNode>, as<...Node>, GetRef<...>) and leading
whitespace were reconstructed from the surrounding types and the hunk line
counts -- confirm against the upstream file before applying.

diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index 7ead6e6ae6fb..f5352f83f718 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -84,7 +84,6 @@ class PrimFuncSpecializer : public StmtExprMutator {
       buffer_map.Set(var, new_buffer);
       if (!new_buffer.same_as(buffer)) {
         buffer_map_updated = true;
-        specializer.buffer_map_[buffer] = new_buffer;
       }
     }
 
@@ -116,7 +115,7 @@ class PrimFuncSpecializer : public StmtExprMutator {
   Stmt VisitStmt_(const BlockNode* op) final {
     // Step.0. Define buffer mappings which is allocated inside the block
     Array<Buffer> alloc_buffers = op->alloc_buffers.Map(
-        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+        std::bind(&PrimFuncSpecializer::MutateBuffer, this, std::placeholders::_1));
 
     // Step.1. Recursively visit block body
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
@@ -141,31 +140,21 @@ class PrimFuncSpecializer : public StmtExprMutator {
   }
 
   Stmt VisitStmt_(const BufferStoreNode* op) final {
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<BufferStoreNode>();
-    ICHECK(op != nullptr);
-    auto it = buffer_map_.find(op->buffer);
-    if (it == buffer_map_.end()) {
-      return GetRef<BufferStore>(op);
-    } else {
-      auto n = CopyOnWrite(op);
-      n->buffer = it->second;
-      return Stmt(n);
+    auto stmt = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    auto new_buffer = MutateBuffer(stmt->buffer);
+    if (!new_buffer.same_as(stmt->buffer)) {
+      stmt.CopyOnWrite()->buffer = new_buffer;
     }
+    return std::move(stmt);
   }
 
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<BufferLoadNode>();
-    ICHECK(op != nullptr);
-    auto it = buffer_map_.find(op->buffer);
-    if (it == buffer_map_.end()) {
-      return GetRef<BufferLoad>(op);
-    } else {
-      auto n = make_object<BufferLoadNode>(*op);
-      n->buffer = it->second;
-      return PrimExpr(n);
+    auto expr = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    auto new_buffer = MutateBuffer(expr->buffer);
+    if (!new_buffer.same_as(expr->buffer)) {
+      expr.CopyOnWrite()->buffer = new_buffer;
     }
+    return std::move(expr);
   }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
@@ -198,22 +187,28 @@ class PrimFuncSpecializer : public StmtExprMutator {
 
  private:
   Buffer MutateBuffer(const Buffer& buffer) {
+    if (auto it = buffer_map_.find(buffer); it != buffer_map_.end()) {
+      return it->second;
+    }
+
     Array<PrimExpr> shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); });
     Array<PrimExpr> strides =
         buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); });
 
     PrimExpr elem_offset = VisitExpr(buffer->elem_offset);
 
-    if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
-        buffer->strides.same_as(strides)) {
-      return buffer;
-    } else {
+    Buffer new_buffer = buffer;
+    if (!buffer->elem_offset.same_as(elem_offset) || !buffer->shape.same_as(shape) ||
+        !buffer->strides.same_as(strides)) {
       auto n = make_object<BufferNode>(*buffer.get());
       n->elem_offset = std::move(elem_offset);
       n->shape = std::move(shape);
       n->strides = std::move(strides);
-      return Buffer(n);
+      new_buffer = Buffer(n);
     }
+
+    buffer_map_[buffer] = new_buffer;
+    return new_buffer;
   }
 
   Range MutateRange(const Range& range) {
@@ -226,26 +221,14 @@ class PrimFuncSpecializer : public StmtExprMutator {
     }
   }
 
-  Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
-    Buffer buf = MutateBuffer(alloc_buf);
-    if (buf.same_as(alloc_buf)) {
-      return alloc_buf;
-    } else {
-      ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
-      buffer_map_[alloc_buf] = buf;
-      return buf;
-    }
-  }
-
   BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
-    auto it = buffer_map_.find(buffer_region->buffer);
-    const Buffer& buffer = it != buffer_map_.end() ? it->second : buffer_region->buffer;
+    Buffer buffer = MutateBuffer(buffer_region->buffer);
     Array<Range> region = buffer_region->region.Map(
         std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1));
-    if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
+    if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) {
       return buffer_region;
     } else {
-      return BufferRegion(buffer, std::move(region));
+      return BufferRegion(std::move(buffer), std::move(region));
     }
   }
 
diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py
index ebae827ef5ad..fa21a3ef6411 100644
--- a/tests/python/unittest/test_tir_specialize.py
+++ b/tests/python/unittest/test_tir_specialize.py
@@ -242,10 +242,30 @@ def test_specialize_with_const_folding():
     tvm.ir.assert_structural_equal(func, param_in_arith_exprs_n_16)
 
 
+def test_specialize_with_pointer_size():
+    """Rewrite implicit buffers
+
+    The first appearance of a `tir::Buffer` may be at its usage in
+    BufferStore or BufferLoad. The specialized parameter should be
+    updated in these buffers as well as explicitly defined buffers.
+    """
+
+    @T.prim_func
+    def before(a_ptr: T.handle("float32"), n: T.int64):
+        A = T.Buffer(n, "float32", data=a_ptr)
+        for i in range(n):
+            A[i] = 0.0
+
+    @T.prim_func
+    def expected(a_ptr: T.handle("float32")):
+        A = T.Buffer(T.int64(16), "float32", data=a_ptr)
+        for i in range(T.int64(16)):
+            A[i] = 0.0
+
+    _, n = before.params
+    after = before.specialize({n: tvm.tir.const(16, n.dtype)})
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
-    test_specialize_nothing()
-    test_specialize_matmul()
-    test_specialize_elemwise()
-    test_specialize_mem_copy()
-    test_specialize_recursive_load()
-    test_specialize_with_const_folding()
+    tvm.testing.main()