Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 1 addition & 42 deletions src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
* under the License.
*/
#include <tvm/runtime/device_api.h>
#include <tvm/tir/stmt_functor.h>

#include "./utils.h"

Expand Down Expand Up @@ -65,47 +64,7 @@ bool IsSimpleBuffer(const tir::Buffer& buf) {
}

int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) {
class OccurrenceCounter : public tir::StmtExprVisitor {
public:
int count = 0;
const tir::VarNode* v = nullptr;

void VisitExpr_(const tir::VarNode* op) final {
if (op == v) {
++count;
}
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::BufferStoreNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const tir::BufferLoadNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::DeclBufferNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitBuffer(const tir::BufferNode* buffer) {
VisitExpr(buffer->data);
for (const PrimExpr& shape_i : buffer->shape) {
VisitExpr(shape_i);
}
for (const PrimExpr& stride_i : buffer->strides) {
VisitExpr(stride_i);
}
VisitExpr(buffer->elem_offset);
}
};

OccurrenceCounter counter;
counter.v = v.get();
OccurrenceCounter counter(v.get());
counter(f->body);
for (const tir::Var& v : f->params) {
counter(v);
Expand Down
24 changes: 24 additions & 0 deletions src/script/printer/tir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,34 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}));
});

bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) {
const tir::Var& buffer_var = allocate->buffer_var;
if (const tir::DeclBufferNode* decl_buffer = allocate->body.as<tir::DeclBufferNode>()) {
const tir::Buffer& buffer = decl_buffer->buffer;
if (buffer_var.same_as(buffer->data) && allocate->dtype == buffer->dtype &&
tir::is_one(allocate->condition) && !allocate->annotations.size() &&
allocate->extents.size() == buffer->shape.size()) {
tir::ExprDeepEqual expr_equal;
for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
return false;
}
}
return true;
}
}
return false;
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Allocate>( //
"", [](tir::Allocate stmt, ObjectPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
OccurrenceCounter counter(stmt->buffer_var.get());
counter(stmt->body);
if (counter.count == 1 && IsAllocateDeclBufferPattern(stmt.get())) {
return d->AsDoc(stmt->body, p->Attr("body"));
}
String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var);
Array<ExprDoc> args;
Array<String> kwargs_keys;
Expand Down
45 changes: 45 additions & 0 deletions src/script/printer/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -220,6 +221,50 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<
ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
const IRDocsifier& d);

/*! \brief A Var occurrence counter visitor */
class OccurrenceCounter : public tir::StmtExprVisitor {
public:
/*! \brief The occurrence counter */
int count = 0;
/*! \brief The Var to count occurrence */
const tir::VarNode* v = nullptr;

void VisitExpr_(const tir::VarNode* op) final {
if (op == v) {
++count;
}
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::BufferStoreNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const tir::BufferLoadNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::DeclBufferNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitBuffer(const tir::BufferNode* buffer) {
VisitExpr(buffer->data);
for (const PrimExpr& shape_i : buffer->shape) {
VisitExpr(shape_i);
}
for (const PrimExpr& stride_i : buffer->strides) {
VisitExpr(stride_i);
}
VisitExpr(buffer->elem_offset);
}

explicit OccurrenceCounter(const tir::VarNode* var) { v = var; }
};

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
92 changes: 49 additions & 43 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tir as T
from tvm.script.printer import default
import tvm.testing


@contextmanager
Expand Down Expand Up @@ -327,6 +328,53 @@ def test_allocate():
)


def test_allocate_with_decl_buffer_sugar():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32") as buffer_data:
with T.decl_buffer([128, 128], "float32", data=buffer_data) as buffer:
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
with T.decl_buffer((128, 128)) as buffer:
T.evaluate(0)
""",
)


def test_allocate_with_decl_buffer_no_sugar_multi_usage():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32") as buffer_data:
with T.decl_buffer([128, 128], "float32", data=buffer_data) as buffer:
T.evaluate(buffer_data)
obj = ib.get()
_assert_print(
obj,
"""
with T.allocate([128, 128], "float32", "global") as v:
buffer = T.decl_buffer((128, 128), data=v)
T.evaluate(v)
""",
)


def test_allocate_with_decl_buffer_no_sugar_mismatch():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32") as buffer_data:
with T.decl_buffer([256, 256], "float32", data=buffer_data) as buffer:
T.evaluate(buffer_data)
obj = ib.get()
_assert_print(
obj,
"""
with T.allocate([128, 128], "float32", "global") as v:
buffer = T.decl_buffer((256, 256), data=v)
T.evaluate(v)
Comment on lines +371 to +373
Copy link
Member

@junrushao junrushao Jan 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we could still do sugaring in this case, i.e.:

with T.decl_buffer((256, 256)) as buffer:
    T.evaluate(buffer.data)

""",
)


def test_decl_buffer():
with IRBuilder() as ib:
with T.decl_buffer((10, 10), data=T.ptr("float32")):
Expand Down Expand Up @@ -686,46 +734,4 @@ def main():


if __name__ == "__main__":
test_prim_func()
test_prim_func_no_sugar_inlined_buffer()
test_prim_func_no_sugar_shared_buffer_data()
test_block_realize()
test_block()
test_buffer()
test_buffer_region()
test_buffer_load()
test_buffer_store()
test_match_buffer_region()
test_for()
test_let_stmt()
test_attr_stmt()
test_assert_stmt()
test_while()
test_allocate()
test_decl_buffer()
test_prefetch()
test_seq_stmt()
test_if_then_else()
test_evaluate()
test_buffer_realize()
test_var()
test_size_var()
test_iter_var()
test_string_imm()
test_cast()
test_binary_arith()
test_logical()
test_select()
test_ramp()
test_broadcast()
test_let_expr()
test_call()
test_comm_reducer()
test_any()
test_int_imm()
test_float_imm()
test_range()
test_prim_type()
test_pointer_type()
test_tuple_type()
test_remap()
tvm.testing.main()