diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 6094eefb65b1..40957fcffaca 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -17,7 +17,6 @@ * under the License. */ #include -#include #include "./utils.h" @@ -65,47 +64,7 @@ bool IsSimpleBuffer(const tir::Buffer& buf) { } int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) { - class OccurrenceCounter : public tir::StmtExprVisitor { - public: - int count = 0; - const tir::VarNode* v = nullptr; - - void VisitExpr_(const tir::VarNode* op) final { - if (op == v) { - ++count; - } - tir::StmtExprVisitor::VisitExpr_(op); - } - - void VisitStmt_(const tir::BufferStoreNode* op) final { - VisitBuffer(op->buffer.get()); - tir::StmtExprVisitor::VisitStmt_(op); - } - - void VisitExpr_(const tir::BufferLoadNode* op) final { - VisitBuffer(op->buffer.get()); - tir::StmtExprVisitor::VisitExpr_(op); - } - - void VisitStmt_(const tir::DeclBufferNode* op) final { - VisitBuffer(op->buffer.get()); - tir::StmtExprVisitor::VisitStmt_(op); - } - - void VisitBuffer(const tir::BufferNode* buffer) { - VisitExpr(buffer->data); - for (const PrimExpr& shape_i : buffer->shape) { - VisitExpr(shape_i); - } - for (const PrimExpr& stride_i : buffer->strides) { - VisitExpr(stride_i); - } - VisitExpr(buffer->elem_offset); - } - }; - - OccurrenceCounter counter; - counter.v = v.get(); + OccurrenceCounter counter(v.get()); counter(f->body); for (const tir::Var& v : f->params) { counter(v); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 7344cb4d98d5..57b4c695a4ee 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -152,10 +152,34 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) })); }); +bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) { + const tir::Var& buffer_var = allocate->buffer_var; + if (const tir::DeclBufferNode* decl_buffer = allocate->body.as()) { + const tir::Buffer& buffer = decl_buffer->buffer; + if (buffer_var.same_as(buffer->data) && allocate->dtype == buffer->dtype && + tir::is_one(allocate->condition) && !allocate->annotations.size() && + allocate->extents.size() == buffer->shape.size()) { + tir::ExprDeepEqual expr_equal; + for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) { + if (!expr_equal(allocate->extents[i], buffer->shape[i])) { + return false; + } + } + return true; + } + } + return false; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::Allocate stmt, ObjectPath p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d); + OccurrenceCounter counter(stmt->buffer_var.get()); + counter(stmt->body); + if (counter.count == 1 && IsAllocateDeclBufferPattern(stmt.get())) { + return d->AsDoc(stmt->body, p->Attr("body")); + } String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); Array args; Array kwargs_keys; diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 183400d974ca..e1ffe135229e 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -220,6 +221,50 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array< ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, const IRDocsifier& d); +/*! \brief A Var occurrence counter visitor */ +class OccurrenceCounter : public tir::StmtExprVisitor { + public: + /*! \brief The occurrence counter */ + int count = 0; + /*! \brief The Var to count occurrence */ + const tir::VarNode* v = nullptr; + + void VisitExpr_(const tir::VarNode* op) final { + if (op == v) { + ++count; + } + tir::StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const tir::BufferStoreNode* op) final { + VisitBuffer(op->buffer.get()); + tir::StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const tir::BufferLoadNode* op) final { + VisitBuffer(op->buffer.get()); + tir::StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const tir::DeclBufferNode* op) final { + VisitBuffer(op->buffer.get()); + tir::StmtExprVisitor::VisitStmt_(op); + } + + void VisitBuffer(const tir::BufferNode* buffer) { + VisitExpr(buffer->data); + for (const PrimExpr& shape_i : buffer->shape) { + VisitExpr(shape_i); + } + for (const PrimExpr& stride_i : buffer->strides) { + VisitExpr(stride_i); + } + VisitExpr(buffer->elem_offset); + } + + explicit OccurrenceCounter(const tir::VarNode* var) { v = var; } +}; + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 201428b74c66..5d86a8860852 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -22,6 +22,7 @@ from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import tir as T from tvm.script.printer import default +import tvm.testing @contextmanager @@ -327,6 +328,53 @@ def test_allocate(): ) +def test_allocate_with_decl_buffer_sugar(): + with IRBuilder() as ib: + with T.allocate([128, 128], "float32") as buffer_data: + with T.decl_buffer([128, 128], "float32", data=buffer_data) as buffer: + T.evaluate(0) + obj = ib.get() + _assert_print( + obj, + """ +with T.decl_buffer((128, 128)) as buffer: + T.evaluate(0) +""", + ) + + +def test_allocate_with_decl_buffer_no_sugar_multi_usage(): + with IRBuilder() as ib: + with T.allocate([128, 128], "float32") as buffer_data: + with T.decl_buffer([128, 128], "float32", data=buffer_data) as buffer: + T.evaluate(buffer_data) + obj = ib.get() + _assert_print( + obj, + """ +with T.allocate([128, 128], "float32", "global") as v: + buffer = T.decl_buffer((128, 128), data=v) + T.evaluate(v) +""", + ) + + +def test_allocate_with_decl_buffer_no_sugar_mismatch(): + with IRBuilder() as ib: + with T.allocate([128, 128], "float32") as buffer_data: + with T.decl_buffer([256, 256], "float32", data=buffer_data) as buffer: + T.evaluate(buffer_data) + obj = ib.get() + _assert_print( + obj, + """ +with T.allocate([128, 128], "float32", "global") as v: + buffer = T.decl_buffer((256, 256), data=v) + T.evaluate(v) +""", + ) + + def test_decl_buffer(): with IRBuilder() as ib: with T.decl_buffer((10, 10), data=T.ptr("float32")): @@ -686,46 +734,4 @@ def main(): if __name__ == "__main__": - test_prim_func() - test_prim_func_no_sugar_inlined_buffer() - test_prim_func_no_sugar_shared_buffer_data() - test_block_realize() - test_block() - test_buffer() - test_buffer_region() - test_buffer_load() - test_buffer_store() - test_match_buffer_region() - test_for() - test_let_stmt() - test_attr_stmt() - test_assert_stmt() - test_while() - test_allocate() - test_decl_buffer() - test_prefetch() - test_seq_stmt() - test_if_then_else() - test_evaluate() - test_buffer_realize() - test_var() - test_size_var() - test_iter_var() - test_string_imm() - test_cast() - test_binary_arith() - test_logical() - test_select() - test_ramp() - test_broadcast() - test_let_expr() - test_call() - test_comm_reducer() - test_any() - test_int_imm() - test_float_imm() - test_range() - test_prim_type() - test_pointer_type() - test_tuple_type() - test_remap() + tvm.testing.main()