From 02392b62bf35bff2e707c759216ec17c3007ee67 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 12 Jun 2023 10:18:45 -0500 Subject: [PATCH 1/6] [TIR][CodeGen] Define PackedFunc error code in MakePackedAPI Previously, the return value of a PackedFunc was hard-coded as the string `"return 0;"` in `CodeGenCHost`, which could cause compilation errors for `PrimFunc` returning `DataType::Void()`. This PR removes this explicit return statement from `CodeGenCHost`, replacing it with `tir::ret(Integer(0))` in the `MakePackedAPI` and `MakeUnpackedAPI` transforms. This is related to https://github.com/apache/tvm/pull/15073, which performs an analogous change for the function signature. --- src/target/source/codegen_c.cc | 3 --- src/target/source/codegen_c.h | 4 ---- src/target/source/codegen_c_host.cc | 5 ----- src/target/source/codegen_c_host.h | 1 - src/tir/transforms/make_packed_api.cc | 9 +++++++-- src/tir/transforms/make_unpacked_api.cc | 6 +++++- 6 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index f6792c1a4e8b..c3929e76f4fc 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -122,7 +122,6 @@ void CodeGenC::AddFunction(const PrimFunc& f) { this->PreFunctionBody(f); int func_scope = this->BeginScope(); this->PrintStmt(f->body); - this->PrintFinalReturn(); this->EndScope(func_scope); this->PrintIndent(); this->stream << "}\n\n"; @@ -132,8 +131,6 @@ void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; } void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {} -void CodeGenC::PrintFinalReturn() {} - std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 4f0da5a9dbad..2775b7251bdd 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -110,10 +110,6 @@ class CodeGenC : public ExprFunctor, * Example: __launch_bounds__(256) for CUDA functions */ virtual void PrintExtraAttrs(const PrimFunc& f); - /*! - * \brief Print the final return at the end the function. - */ - virtual void PrintFinalReturn(); // NOLINT(*) /*! * \brief Insert statement before function body. * \param f The function to be compiled. diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 1d8071774e9e..f6b1808917be 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -125,11 +125,6 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*) << "TVM_DLL int32_t"; } -void CodeGenCHost::PrintFinalReturn() { // NOLINT(*) - this->PrintIndent(); - stream << "return 0;\n"; -} - std::string CodeGenCHost::Finish() { // NOLINT(*) std::string ret = decl_stream.str(); if (emit_fwd_func_decl_) { diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 6bae574627d5..828dc36f87a8 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -57,7 +57,6 @@ class CodeGenCHost : public CodeGenC { void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*) - void PrintFinalReturn() final; // NOLINT(*) // overload visitor functions void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index a6673a19ad01..d9883196ef94 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -353,11 +353,16 @@ PrimFunc MakePackedAPI(PrimFunc func) { } } + // Apply all argument assertions std::ostringstream num_args_error; num_args_error << name_hint << ": num_args should be " << num_args; std::vector arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())}; - func_ptr->body = - MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); + body = MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); + + // Return error code of zero on success + body = SeqStmt({body, Evaluate(ret(Integer(0)))}); + + func_ptr->body = body; func_ptr->params = args; Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 4b1b3bf517d0..ebe3fd97ff10 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -147,7 +147,11 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop)); } - func_ptr->body = MergeNest(device_init, func_ptr->body); + Stmt body = func_ptr->body; + body = MergeNest(device_init, body); + body = SeqStmt({body, Evaluate(ret(Integer(0)))}); + + func_ptr->body = body; func_ptr->params = args; func_ptr->ret_type = PrimType(DataType::Int(32)); func_ptr->buffer_map = Map(); From 1762438ee82aa997cecff498ee0e60f34c9a3f19 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 12 Jun 2023 12:49:26 -0500 Subject: [PATCH 2/6] Handle builtin::ret() in CodeGenC --- src/target/source/codegen_c.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index c3929e76f4fc..27ff92507e7a 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -534,6 +534,9 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->args[0], os); os << " ) return "; PrintExpr(op->args[1], os); + } else if (op->op.same_as(builtin::ret())) { + os << "return "; + PrintExpr(op->args[0], os); } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { ICHECK_GE(op->args.size(), 1U); auto func = Downcast(op->args[0]); From 184e1267f3b5a078faa3402ed971b0bec09f3736 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Tue, 13 Jun 2023 09:13:34 -0500 Subject: [PATCH 3/6] Place T.ret(0) inside asserts, rather than outside This causes fewer unit tests to break, and has more readable TVMScript. --- src/tir/transforms/make_packed_api.cc | 6 +++--- src/tir/transforms/make_unpacked_api.cc | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d9883196ef94..e387204045dd 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -353,15 +353,15 @@ PrimFunc MakePackedAPI(PrimFunc func) { } } + // Return error code of zero on success + body = SeqStmt({body, Evaluate(ret(Integer(0)))}); + // Apply all argument assertions std::ostringstream num_args_error; num_args_error << name_hint << ": num_args should be " << num_args; std::vector arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())}; body = MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); - // Return error code of zero on success - body = SeqStmt({body, Evaluate(ret(Integer(0)))}); - func_ptr->body = body; func_ptr->params = args; diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index ebe3fd97ff10..2646b5baea7c 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -147,9 +147,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop)); } - Stmt body = func_ptr->body; - body = MergeNest(device_init, body); - body = SeqStmt({body, Evaluate(ret(Integer(0)))}); + Stmt body = MergeNest(device_init, SeqStmt({func_ptr->body, Evaluate(ret(Integer(0)))})); func_ptr->body = body; func_ptr->params = args; From ecf61435df2a7f18fcf27fc4685b275a2ccb3346 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Tue, 13 Jun 2023 09:14:08 -0500 Subject: [PATCH 4/6] Update unit tests to look inside SeqStmt --- .../test_tir_transform_lower_tvm_builtin.py | 9 +++++++-- .../test_tir_transform_make_packed_api.py | 15 ++++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 2e0784cc3126..6eac5e90b553 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -70,8 +70,13 @@ def check_packed_func(target="llvm"): node = prim_func.body # Recursively visit PrimFunc until we meet the for-loop: - while isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)): - node = node.body + while True: + if isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)): + node = node.body + elif isinstance(node, tvm.tir.SeqStmt): + node = node[0] + else: + break # For-loop: assert isinstance(node, tvm.tir.stmt.For) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 34adcbb9aee4..6f84b6f6d48c 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -60,9 +60,18 @@ def _find_assignment(stmt, var_name): def _find_next(stmt, type): - while not isinstance(stmt, type): - stmt = stmt.body - return stmt + search_stack = [stmt] + + while search_stack: + stmt = search_stack.pop() + if isinstance(stmt, type): + return stmt + elif isinstance(stmt, tvm.tir.SeqStmt): + search_stack.extend(reversed(stmt)) + else: + search_stack.append(stmt.body) + + return None def _find_compute_scope(func): From c2d0712e12ea4691a6109db4b854162f547dbd40 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Tue, 13 Jun 2023 09:14:36 -0500 Subject: [PATCH 5/6] Handle T.ret(0) in CodeGenStackVM --- src/target/stackvm/codegen_stackvm.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index db6e32d65f04..fa2cd6b09d13 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -284,6 +284,12 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { this->Push(op->args[0]); this->PushOp(StackVM::PUSH_I64, 0); this->PushOp(StackVM::EQ_HANDLE); + } else if (op->op.same_as(builtin::ret())) { + CHECK(op->args.size() == 1 && op->args[0]->IsInstance() && + op->args[0].as()->value == 0) + << "StackVM does not support return values, " + << "and the return value " << op->args + << " is not special case of returning an error code of zero."; } else { LOG(FATAL) << "unknown function call " << op->op; } From 60eed14694c6aff5cea5b87ea9f3345eec4e1893 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Tue, 13 Jun 2023 09:14:50 -0500 Subject: [PATCH 6/6] Update MakeUnpackedAPI tests to expect T.ret --- .../python/unittest/test_tir_transform_make_unpacked_api.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py index 1931f7aef324..868d30db3618 100644 --- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -172,6 +172,7 @@ def main(A_data: T.handle("float32")) -> T.int32: T.attr("default", "device_id", 0) T.attr("default", "device_type", 2) mod.subroutine(A_data) + T.ret(T.int32(0)) @T.prim_func def subroutine(A_data: T.handle("float32")): @@ -215,6 +216,7 @@ def main(A_data: T.handle("float32")) -> T.int32: T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) mod.subroutine(A_data) + T.ret(T.int32(0)) @T.prim_func def subroutine(A_data: T.handle("float32")): @@ -259,11 +261,13 @@ def main(A_data: T.handle("float32")) -> T.int32: T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) mod.subroutine(A_data) + T.ret(T.int32(0)) @T.prim_func def subroutine(A_data: T.handle("float32")) -> T.int32: T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) T.evaluate(A_data) + T.ret(T.int32(0)) return mod @@ -316,6 +320,7 @@ def main(A_data: T.handle("float32")) -> T.int32: T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) mod.subroutine(A_data) + T.ret(T.int32(0)) @T.prim_func def subroutine(A_data: T.handle("float32")) -> T.int32: @@ -323,6 +328,7 @@ def subroutine(A_data: T.handle("float32")) -> T.int32: T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) T.evaluate(A_data) + T.ret(T.int32(0)) return mod