From c2618812dcba04d043bcde2328badba5d04e375f Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 20 Jan 2023 14:37:22 -0800 Subject: [PATCH 1/2] root block syntax --- src/script/printer/tir/function.cc | 46 +++++++++++++++- .../unittest/test_tvmscript_printer_tir.py | 52 ++++++++++++++----- 2 files changed, 85 insertions(+), 13 deletions(-) diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 40957fcffaca..e479019d8603 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -131,7 +131,51 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } // Step 4. Handle `func->body` - AsDocBody(func->body, p->Attr("body"), frame->get(), d); + Optional implicit_root_block = [&]() -> Optional { + const tir::BlockRealizeNode* root_block_realize = func->body.as(); + if (root_block_realize && !root_block_realize->iter_values.size()) { + tir::Block root_block = root_block_realize->block; + if (!root_block->annotations.size()) { + const tir::BlockRealizeNode* block_realize = + root_block->body.as(); + if (root_block->alloc_buffers.size() || + (block_realize && block_realize->block->iter_vars.size()) || + (!block_realize && tir::ContainsNode(root_block->body))) { + return root_block; + } + } + } + return NullOpt; + }(); + if (implicit_root_block) { + tir::Block root_block = implicit_root_block.value(); + ObjectPath root_block_p = p->Attr("body")->Attr("body"); + // Handle root block `alloc_buffer` + for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { + tir::Buffer buffer = root_block->alloc_buffers[i]; + ObjectPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayIndex(i); + IdDoc lhs = DefineBuffer(buffer, *frame, d); + ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *frame, d); + (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + } + // Handle root block `match_buffer` + for (int i = 0, n = root_block->match_buffers.size(); i < n; ++i) { + tir::MatchBufferRegion buffer_region = root_block->match_buffers[i]; + ObjectPath buffer_region_p = root_block_p->Attr("match_buffers")->ArrayIndex(i); + StmtDoc doc = d->AsDoc(buffer_region, buffer_region_p); + (*frame)->stmts.push_back(doc); + } + // Handle root block `init` block + if (root_block->init.defined()) { + tir::Stmt init = root_block->init.value(); + With init_frame(d, init); + AsDocBody(init, root_block_p->Attr("init"), init_frame->get(), d); + (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR("init")->Call({}), (*init_frame)->stmts)); + } + AsDocBody(root_block->body, root_block_p->Attr("body"), frame->get(), d); + } else { + AsDocBody(func->body, p->Attr("body"), frame->get(), d); + } Optional ret_type = NullOpt; if (func->ret_type.defined()) { const auto* as_tuple = func->ret_type.as(); diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 5d86a8860852..d57d10467077 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -717,21 +717,49 @@ def block_with_remap_explicitly(): expected_output = """@T.prim_func def main(): - with T.block("root"): - T.reads() - T.writes() - for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): - with T.block("update"): - v0 = T.axis.spatial(128, i0 + 1) - v1, v2 = T.axis.remap("SR", [i1, i2]) - v3 = T.axis.spatial(128, i3 - 1) - v4, v5 = T.axis.remap("RS", [i4, i5]) - T.reads() - T.writes() - T.evaluate(0)""" + for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): + with T.block("update"): + v0 = T.axis.spatial(128, i0 + 1) + v1, v2 = T.axis.remap("SR", [i1, i2]) + v3 = T.axis.spatial(128, i3 - 1) + v4, v5 = T.axis.remap("RS", [i4, i5]) + T.reads() + T.writes() + T.evaluate(0)""" _assert_print(block_with_remap_explicitly, expected_output) _assert_print(block_with_remap_implicitly, expected_output) +def test_root_block(): + from tvm.script import tir as T + + @T.prim_func + def root_block_implicitly(): + a = T.alloc_buffer([128, 128]) + for i, j in T.grid(128, 128): + with T.block(): + T.evaluate(0) + + @T.prim_func + def root_block_explicitly(): + with T.block("root"): + a = T.alloc_buffer([128, 128]) + for i, j in T.grid(128, 128): + with T.block(): + T.evaluate(0) + + expected_output = """@T.prim_func +def main(): + a = T.alloc_buffer((128, 128)) + for i, j in T.grid(128, 128): + with T.block(""): + T.reads() + T.writes() + T.evaluate(0) + """ + _assert_print(root_block_implicitly, expected_output) + _assert_print(root_block_explicitly, expected_output) + + if __name__ == "__main__": tvm.testing.main() From 0b63fb1c52744a4c8d2d5b368033902b0e73ca1e Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 20 Jan 2023 17:35:13 -0800 Subject: [PATCH 2/2] fix --- src/script/printer/tir/function.cc | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index e479019d8603..ea7d56e1656d 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -133,9 +133,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 4. Handle `func->body` Optional implicit_root_block = [&]() -> Optional { const tir::BlockRealizeNode* root_block_realize = func->body.as(); - if (root_block_realize && !root_block_realize->iter_values.size()) { + if (root_block_realize && !root_block_realize->iter_values.size() && + tir::is_one(root_block_realize->predicate)) { tir::Block root_block = root_block_realize->block; - if (!root_block->annotations.size()) { + if (!root_block->annotations.size() && !root_block->match_buffers.size() && + !root_block->reads.size() && !root_block->writes.size() && + !root_block->init.defined()) { const tir::BlockRealizeNode* block_realize = root_block->body.as(); if (root_block->alloc_buffers.size() || @@ -158,20 +161,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *frame, d); (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } - // Handle root block `match_buffer` - for (int i = 0, n = root_block->match_buffers.size(); i < n; ++i) { - tir::MatchBufferRegion buffer_region = root_block->match_buffers[i]; - ObjectPath buffer_region_p = root_block_p->Attr("match_buffers")->ArrayIndex(i); - StmtDoc doc = d->AsDoc(buffer_region, buffer_region_p); - (*frame)->stmts.push_back(doc); - } - // Handle root block `init` block - if (root_block->init.defined()) { - tir::Stmt init = root_block->init.value(); - With init_frame(d, init); - AsDocBody(init, root_block_p->Attr("init"), init_frame->get(), d); - (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR("init")->Call({}), (*init_frame)->stmts)); - } AsDocBody(root_block->body, root_block_p->Attr("body"), frame->get(), d); } else { AsDocBody(func->body, p->Attr("body"), frame->get(), d);