diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index f6792c1a4e8b..27ff92507e7a 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -122,7 +122,6 @@ void CodeGenC::AddFunction(const PrimFunc& f) { this->PreFunctionBody(f); int func_scope = this->BeginScope(); this->PrintStmt(f->body); - this->PrintFinalReturn(); this->EndScope(func_scope); this->PrintIndent(); this->stream << "}\n\n"; @@ -132,8 +131,6 @@ void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; } void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {} -void CodeGenC::PrintFinalReturn() {} - std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) @@ -537,6 +534,9 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->args[0], os); os << " ) return "; PrintExpr(op->args[1], os); + } else if (op->op.same_as(builtin::ret())) { + os << "return "; + PrintExpr(op->args[0], os); } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { ICHECK_GE(op->args.size(), 1U); auto func = Downcast(op->args[0]); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 4f0da5a9dbad..2775b7251bdd 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -110,10 +110,6 @@ class CodeGenC : public ExprFunctor, * Example: __launch_bounds__(256) for CUDA functions */ virtual void PrintExtraAttrs(const PrimFunc& f); - /*! - * \brief Print the final return at the end the function. - */ - virtual void PrintFinalReturn(); // NOLINT(*) /*! * \brief Insert statement before function body. * \param f The function to be compiled. diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 1d8071774e9e..f6b1808917be 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -125,11 +125,6 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*) << "TVM_DLL int32_t"; } -void CodeGenCHost::PrintFinalReturn() { // NOLINT(*) - this->PrintIndent(); - stream << "return 0;\n"; -} - std::string CodeGenCHost::Finish() { // NOLINT(*) std::string ret = decl_stream.str(); if (emit_fwd_func_decl_) { diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 6bae574627d5..828dc36f87a8 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -57,7 +57,6 @@ class CodeGenCHost : public CodeGenC { void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*) - void PrintFinalReturn() final; // NOLINT(*) // overload visitor functions void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index db6e32d65f04..fa2cd6b09d13 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -284,6 +284,12 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { this->Push(op->args[0]); this->PushOp(StackVM::PUSH_I64, 0); this->PushOp(StackVM::EQ_HANDLE); + } else if (op->op.same_as(builtin::ret())) { + CHECK(op->args.size() == 1 && op->args[0]->IsInstance() && + op->args[0].as()->value == 0) + << "StackVM does not support return values, " + << "and the return value " << op->args + << " is not special case of returning an error code of zero."; } else { LOG(FATAL) << "unknown function call " << op->op; } diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index a6673a19ad01..e387204045dd 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -353,11 +353,16 @@ PrimFunc MakePackedAPI(PrimFunc func) { } } + // Return error code of zero on success + body = SeqStmt({body, Evaluate(ret(Integer(0)))}); + + // Apply all argument assertions std::ostringstream num_args_error; num_args_error << name_hint << ": num_args should be " << num_args; std::vector arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())}; - func_ptr->body = - MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); + body = MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); + + func_ptr->body = body; func_ptr->params = args; Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 4b1b3bf517d0..2646b5baea7c 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -147,7 +147,9 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop)); } - func_ptr->body = MergeNest(device_init, func_ptr->body); + Stmt body = MergeNest(device_init, SeqStmt({func_ptr->body, Evaluate(ret(Integer(0)))})); + + func_ptr->body = body; func_ptr->params = args; func_ptr->ret_type = PrimType(DataType::Int(32)); func_ptr->buffer_map = Map(); diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 2e0784cc3126..6eac5e90b553 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -70,8 +70,13 @@ def check_packed_func(target="llvm"): node = prim_func.body # Recursively visit PrimFunc until we meet the for-loop: - while isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)): - node = node.body + while True: + if isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)): + node = node.body + elif isinstance(node, tvm.tir.SeqStmt): + node = node[0] + else: + break # For-loop: assert isinstance(node, tvm.tir.stmt.For) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 34adcbb9aee4..6f84b6f6d48c 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -60,9 +60,18 @@ def _find_assignment(stmt, var_name): def _find_next(stmt, type): - while not isinstance(stmt, type): - stmt = stmt.body - return stmt + search_stack = [stmt] + + while search_stack: + stmt = search_stack.pop() + if isinstance(stmt, type): + return stmt + elif isinstance(stmt, tvm.tir.SeqStmt): + search_stack.extend(reversed(stmt)) + else: + search_stack.append(stmt.body) + + return None def _find_compute_scope(func): diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py index 1931f7aef324..868d30db3618 100644 --- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -172,6 +172,7 @@ def main(A_data: T.handle("float32")) -> T.int32: T.attr("default", "device_id", 0) T.attr("default", "device_type", 2) mod.subroutine(A_data) + T.ret(T.int32(0)) @T.prim_func def subroutine(A_data: T.handle("float32")): @@ -215,6 +216,7 @@ def main(A_data: T.handle("float32")) -> T.int32: T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) mod.subroutine(A_data) + T.ret(T.int32(0)) @T.prim_func def subroutine(A_data: T.handle("float32")): @@ -259,11 +261,13 @@ def main(A_data: T.handle("float32")) -> T.int32: T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) mod.subroutine(A_data) + T.ret(T.int32(0)) @T.prim_func def subroutine(A_data: T.handle("float32")) -> T.int32: T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) T.evaluate(A_data) + T.ret(T.int32(0)) return mod @@ -316,6 +320,7 @@ def main(A_data: T.handle("float32")) -> T.int32: T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) mod.subroutine(A_data) + T.ret(T.int32(0)) @T.prim_func def subroutine(A_data: T.handle("float32")) -> T.int32: @@ -323,6 +328,7 @@ def subroutine(A_data: T.handle("float32")) -> T.int32: T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) T.evaluate(A_data) + T.ret(T.int32(0)) return mod