diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index a4caeee43604..6aa1aca69970 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -192,16 +192,13 @@ TVM_DLL Pass InstrumentBoundCheckers(); * - Map the values in the api_args to Var that is required by body. * - Insert assertions to check type/value of the passed arguments. * - * \param num_unpacked_args Number of arguments that - * are processed in plain form instead of packed form. - * * \note * The function signature have two cases * - * let num_packed_args = len(api_args) - num_unpacked_args; + * let num_packed_args = len(api_args); * * if num_packed_args is zero: - * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) + * f() * * if num_packed_args is not zero: * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, @@ -212,7 +209,7 @@ TVM_DLL Pass InstrumentBoundCheckers(); * * \return The pass. */ -TVM_DLL Pass MakePackedAPI(int num_unpacked_args); +TVM_DLL Pass MakePackedAPI(); /*! * \brief Transform the high-level PrimFunc to a C signature that can be used diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 324471c71891..3c1ca196f1b0 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -387,22 +387,15 @@ def LowerCustomDatatypes(): return _ffi_api.LowerCustomDatatypes() # type: ignore -def MakePackedAPI(num_unpacked_params: int = -1): +def MakePackedAPI(): """Transform the PrimFuncs in the module to a packed func API. - Parameters - ---------- - num_unpacked_params : int - Number of parameters that we hope to directly pass via normal arguments - following the PackedFunc input signature. If it is specified as -1 or it - is less than the number of arguments, the pass will packed arguments still. - Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.MakePackedAPI(num_unpacked_params) # type: ignore + return _ffi_api.MakePackedAPI() # type: ignore def MakeUnpackedAPI(): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 1a617dcd494d..b460557da034 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -561,7 +561,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) if (unpacked_api) { mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); } else { - mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); + mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } mixed_pass_list.push_back(tir::transform::SplitHostDevice()); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 4f8ad1223cd2..bf7ff09c86c7 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -139,7 +139,7 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } -PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { +PrimFunc MakePackedAPI(PrimFunc&& func) { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; @@ -152,14 +152,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { auto* func_ptr = func.CopyOnWrite(); const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); - ICHECK_LE(num_unpacked_args, num_args); - bool pack_args = (num_unpacked_args == -1) || (num_args > num_unpacked_args); - if (num_unpacked_args == -1) { - // reset to zero - num_unpacked_args = 0; - } - ICHECK_GE(num_unpacked_args, 0); - int num_packed_args = num_args - num_unpacked_args; + // Data field definitions // The packed fields Var v_packed_args("args", DataType::Handle()); @@ -170,7 +163,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { Var v_out_ret_tcode("out_ret_tcode", PointerType(PrimType(DataType::Int(32)))); Var v_resource_handle("resource_handle", DataType::Handle()); // The arguments of the function. - Array args; + // The device context Var device_id("dev_id"); Integer device_type(target_device_type); @@ -194,14 +187,6 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { } return res; }; - // --------------------------- - // start of logics - // add signature for packed arguments. - if (pack_args) { - args.push_back(v_packed_args); - args.push_back(buf_packed_arg_type_ids->data); - args.push_back(v_num_packed_args); - } // Need to re-declare vars, in case some arguments also appears in the buffer. std::vector> var_def; @@ -219,7 +204,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { // Pluck the device API context out based on name if (param->name_hint == kDeviceContextVar) { - num_packed_args--; + num_args--; v_resource_handle = param; continue; } @@ -232,44 +217,34 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { var_def.emplace_back(v_arg, param); } - if (i < num_packed_args) { - // Value loads - seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop)); - // type code checks - Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); - seq_init.emplace_back( - LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop)); - DataType t = v_arg.dtype(); - if (t.is_handle()) { - std::ostringstream msg; - msg << name_hint << ": Expect arg[" << i << "] to be pointer"; - seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || - tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, - tvm::tir::StringImm(msg.str()), nop)); - } else if (t.is_int() || t.is_uint()) { - std::ostringstream msg; - msg << name_hint << ": Expect arg[" << i << "] to be int"; - seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); - } else { - ICHECK(t.is_float()); - std::ostringstream msg; - msg << name_hint << ": Expect arg[" << i << "] to be float"; - seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); - } + // Value loads + seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop)); + // type code checks + Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); + seq_init.emplace_back( + LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop)); + DataType t = v_arg.dtype(); + if (t.is_handle()) { + std::ostringstream msg; + msg << name_hint << ": Expect arg[" << i << "] to be pointer"; + seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || + tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, + tvm::tir::StringImm(msg.str()), nop)); + } else if (t.is_int() || t.is_uint()) { + std::ostringstream msg; + msg << name_hint << ": Expect arg[" << i << "] to be int"; + seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); } else { - args.push_back(v_arg); + ICHECK(t.is_float()); + std::ostringstream msg; + msg << name_hint << ": Expect arg[" << i << "] to be float"; + seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); } } - // allow return value if the function is packed. - if (pack_args) { - args.push_back(v_out_ret_value); - args.push_back(v_out_ret_tcode); - args.push_back(v_resource_handle); - } - - size_t expected_nargs = num_unpacked_args + (pack_args ? 6 : 0); - ICHECK_EQ(args.size(), expected_nargs); + Array args{v_packed_args, buf_packed_arg_type_ids->data, + v_num_packed_args, v_out_ret_value, + v_out_ret_tcode, v_resource_handle}; // Arg definitions are defined before buffer binding to avoid the use before // def errors. @@ -286,9 +261,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint); } - if (num_unpacked_args == 0) { - func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); - } + func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode); body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, @@ -307,16 +280,11 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { } } - if (pack_args) { - std::ostringstream num_args_error; - num_args_error << name_hint << ": num_args should be " << num_packed_args; - std::vector arg_assert = { - MakeAssertEQ(v_num_packed_args, num_packed_args, num_args_error.str())}; - func_ptr->body = - MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); - } else { - func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); - } + std::ostringstream num_args_error; + num_args_error << name_hint << ": num_args should be " << num_args; + std::vector arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())}; + func_ptr->body = + MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); func_ptr->params = args; Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); @@ -339,9 +307,8 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { namespace transform { -Pass MakePackedAPI(int num_unpacked_args) { - // packed arguments anyway while `num_unpacked_args` is -1 - auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) { +Pass MakePackedAPI() { + auto pass_func = [](IRModule m, PassContext ctx) { IRModuleNode* mptr = m.CopyOnWrite(); std::vector> updates; @@ -350,7 +317,7 @@ Pass MakePackedAPI(int num_unpacked_args) { PrimFunc func = GetRef(n); if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { - auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); + auto updated_func = MakePackedAPI(std::move(func)); updates.push_back({kv.first, updated_func}); } } @@ -365,7 +332,7 @@ Pass MakePackedAPI(int num_unpacked_args) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {}); } -TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed(MakePackedAPI); +TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed([]() { return MakePackedAPI(); }); } // namespace transform } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 047c95b6134f..e78ed98d8569 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -39,9 +39,8 @@ def test_makeapi(): ) )(mod) - num_unpacked_args = 2 - f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"] - assert len(f.params) == 8 + f = tvm.tir.transform.MakePackedAPI()(mod)["main"] + assert len(f.params) == 6 def _find_assignment(stmt, var_name):