diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index d830ea579aa7..35079c74061b 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -750,6 +750,50 @@ TVM_DLL const Op& start_profile_intrinsic();
  */
 TVM_DLL const Op& end_profile_intrinsic();
 
+/*!
+ * \brief Get an item from any list and return it.
+ *
+ *  Any anylist_getitem(Handle anylist,
+ *                      int index) {
+ *     return anylist[index];
+ *  }
+ *
+ * \note This intrinsic is only applicable when appearing
+ *       in call_packed and anylist_setitem_call_packed.
+ */
+TVM_DLL const Op& anylist_getitem();
+
+/*!
+ * \brief Reset and clear an item in any list.
+ *
+ *  void anylist_resetitem(Handle anylist,
+ *                         int index) {
+ *     anylist[index] = nullptr;
+ *  }
+ *
+ * \note This intrinsic is only applicable when appearing
+ *       in call_packed and anylist_setitem_call_packed.
+ */
+TVM_DLL const Op& anylist_resetitem();
+
+/*!
+ * \brief Set an item in any list to the result of a packed function call.
+ *
+ *  void anylist_setitem_call_packed(Handle anylist,
+ *                                   int index,
+ *                                   name, *args) {
+ *
+ *     anylist[index] = call_packed(name, *args)
+ *  }
+ * \note This intrinsic can be used in combination with anylist_getitem.
+ */
+TVM_DLL const Op& anylist_setitem_call_packed();
+
+/*!
+ * \brief Same as anylist_setitem_call_packed but uses the C calling convention.
+ */
+TVM_DLL const Op& anylist_setitem_call_cpacked();
+
 /*! \brief The kind of structure field info used in intrinsic */
 enum TVMStructFieldKind : int {
   // array head address
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 5f4e9d4f2cf0..601963565fff 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1713,6 +1713,10 @@ def wrapped(*args, **kwargs):
 TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
 start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic)
 end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic)
+anylist_getitem = _op_wrapper(_tir_op.anylist_getitem)
+anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem)
+anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed)
+anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked)
 
 
 def _dtype_forward(func):
@@ -1988,6 +1992,10 @@ def wrapped(*args, **kwargs):
     "start_profile_intrinsic",
     "end_profile_intrinsic",
     "meta_var",
+    "anylist_getitem",
+    "anylist_resetitem",
+    "anylist_setitem_call_packed",
+    "anylist_setitem_call_cpacked",
     "llvm_lookup_intrinsic_id",
     "type_annotation",
     "broadcast",
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 0a9c4fdfaa52..14decca77e51 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -2931,6 +2931,74 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr):
     return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)
 
 
+def anylist_getitem(list_handle, index):
+    """Returns an item from any list.
+    list_handle: Var
+        The handle to anylist
+    index : int
+        The index
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.anylist_getitem", list_handle, index)
+
+
+def anylist_resetitem(list_handle, index):
+    """Reset an item in any list.
+    list_handle: Var
+        The handle to anylist
+    index : int
+        The index
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+ """ + return call_intrin("int", "tir.anylist_resetitem", list_handle, index) + + +def anylist_setitem_call_packed(list_handle, index, func_name, *args): + """Set anylist item by result of packed call. + list_handle: Var + The handle to anylist + index : int + The index + func_name: str + The name of the function to be called. + args: + Extra arguments + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "int", "tir.anylist_setitem_call_packed", list_handle, index, func_name, *args + ) + + +def anylist_setitem_call_cpacked(list_handle, index, func_name, *args): + """Set anylist item by result of packed call. + list_handle: Var + The handle to anylist + index : int + The index + func_name: str + The name of the function to be called. + args: + Extra arguments + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "int", "tir.anylist_setitem_call_cpacked", list_handle, index, func_name, *args + ) + + # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc new file mode 100644 index 000000000000..2f63a50d370f --- /dev/null +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -0,0 +1,511 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/vm/codegen_tir.cc + * \brief A codegen to generate VMTIR function(that can be compiled) from executable. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace relax { +namespace relax_vm { + +using vm::VMFuncInfo; + +/*! + * \brief A class to generate VMTIR for Relax functions. + * + * \note Skip CallPacked with special attrs for now, as they can be + * further simplified with PrimValue. + */ +class CodeGenVMTIR : public ExprFunctor(const Expr&)> { + public: + explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod) + : builder_(builder), ctx_mod_(ctx_mod) {} + + static IRModule Run(relax::ExecBuilder builder, IRModule mod) { + // create a new copy + IRModule res_mod = mod; + res_mod.CopyOnWrite(); + + CodeGenVMTIR codegen(builder, mod); + // Remove relax function and turn into TIR func. 
+ for (auto& p : mod->functions) { + if (auto* func = p.second.as()) { + auto tir_func = codegen.Codegen(GetRef(func)); + auto gsymbol = tir_func->GetAttr(tvm::attr::kGlobalSymbol); + res_mod->Add(GlobalVar(gsymbol.value()), tir_func); + res_mod->Remove(p.first); + } + } + return res_mod; + } + + private: + int64_t NewRegister() { return registers_num_++; } + + static IntImm ConstInt64(int64_t value) { return IntImm(DataType::Int(64), value); } + + static IntImm ConstInt32(int64_t value) { return IntImm(DataType::Int(32), value); } + + PrimExpr RegListGet(int64_t slot) const { + // use 128 bits to represent any + return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), + {reg_anylist_handle_, ConstInt32(slot)}); + } + + PrimExpr ConstListGet(int64_t slot) const { + // use 128 bits to represent any + return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), + {const_anylist_handle_, ConstInt32(slot)}); + } + + PrimExpr FuncListGet(int64_t slot) const { + // use 128 bits to represent any + return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), + {func_anylist_handle_, ConstInt32(slot)}); + } + + void EmitStmt(tir::Stmt stmt) { + ICHECK(!stmt_stack_.empty()); + stmt_stack_.back().emplace_back(stmt); + } + + void EmitCallPacked(String name, const Array& args, int64_t dst_anylist_slot = -1) { + Array all_args; + // negative index indicate return value can be discarded, emit call_packed + if (dst_anylist_slot >= 0) { + all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; + } + all_args.push_back(tir::StringImm(name)); + for (PrimExpr arg : args) { + all_args.push_back(arg); + } + if (dst_anylist_slot >= 0) { + this->EmitStmt(tir::Evaluate( + tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_packed(), all_args))); + } else { + this->EmitStmt( + tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), all_args))); + } + } + + void EmitCallCPacked(const tir::PrimFunc& prim_func, const Array& args, + int64_t dst_anylist_slot = -1) { + Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()) << "All functions must have global symbol at this phase"; + Array all_args; + // negative index indicate return value can be discarded, emit call_packed + if (dst_anylist_slot >= 0) { + all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; + } + all_args.push_back(tir::StringImm(gsymbol.value())); + for (PrimExpr arg : args) { + all_args.push_back(arg); + } + // push an empty handle to be compatible with current cpacked convention + // TODO(tqchen): revisit C Packed convention + all_args.push_back(tir::make_zero(DataType::Handle())); + if (dst_anylist_slot >= 0) { + this->EmitStmt(tir::Evaluate( + tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_cpacked(), all_args))); + } else { + this->EmitStmt( + tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_cpacked(), all_args))); + } + } + + tir::PrimFunc Codegen(const Function& func) { + Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. 
" + "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; + // initialize the state + stmt_stack_ = {}; + registers_num_ = 0; + var_map_.clear(); + ctx_ptr_ = tir::Var("ctx_ptr", DataType::Handle()); + reg_anylist_handle_ = tir::Var("r", DataType::Handle()); + func_anylist_handle_ = tir::Var("f", DataType::Handle()); + const_anylist_handle_ = tir::Var("c", DataType::Handle()); + + Array param_names; + for (Var param : func->params) { + param_names.push_back(param->name_hint()); + } + // declare this function. + builder_->DeclareFunction(gsymbol.value(), vm::VMFuncInfo::FuncKind::kVMTIRFunc); + + for (size_t i = 0; i < func->params.size(); ++i) { + int64_t r = NewRegister(); + ICHECK_EQ(static_cast(r), i); + this->var_map_.insert({func->params[i], RegListGet(r)}); + } + size_t ret_reg = NewRegister(); + + tir::Stmt body = WithNewScope([&]() { + Optional ret = ExprFunctor::VisitExpr(func->body); + if (ret.defined()) { + this->EmitCallPacked("vm.builtin.copy", {ret.value()}, ret_reg); + } + }); + + // Mark the function entry internally. + builder_->EmitFunction(gsymbol.value(), param_names.size(), param_names, + VMFuncInfo::FuncKind::kVMTIRFunc, registers_num_); + builder_->EndFunction(gsymbol.value()); + + Type ret_type = VoidType(); + Array tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_, + func_anylist_handle_}; + String tir_func_name = "__vmtir__" + gsymbol.value(); + tir::PrimFunc tir_func(tir_params, body, ret_type, {}); + tir_func = WithAttr(tir_func, "global_symbol", tir_func_name); + registers_num_ = 0; + var_map_.clear(); + stmt_stack_.clear(); + return tir_func; + } + + Optional VisitExpr_(const SeqExprNode* op) final { + for (auto block : op->blocks) { + for (Binding binding : block->bindings) { + Optional value; + if (auto* var_binding = binding.as()) { + value = this->VisitExpr(var_binding->value); + } else if (auto* match_cast = binding.as()) { + value = this->VisitExpr(match_cast->value); + } else { + LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey(); + } + this->var_map_.insert({binding->var, value}); + } + } + return this->VisitExpr(op->body); + } + + Optional VisitExpr_(const CallNode* call_node) final { + Call call = GetRef(call_node); + + if (call_node->op == null_value_op_) { + return tir::Call(DataType::Handle(), tir::builtin::reinterpret(), + {IntImm(DataType::Int(64), 0)}); + } + int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); + if (call->op.as()) { + if (call_node->op == call_builtin_with_ctx_op_) { + EmitCallBuiltinWithCtx(call, dst_reg); + } else if (call_node->op == alloc_storage_op_) { + EmitAllocStorage(call, dst_reg); + } else if (call_node->op == alloc_tensor_op_) { + EmitAllocTensor(call, dst_reg); + } else { + // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those + // ops are handled in a pass when lowering them to TIR. + LOG(FATAL) << "CodeGenVMTIR cannot handle this intrinsic now:\n" << call_node->op; + } + } else { + EmitNormalCall(call, dst_reg); + } + if (dst_reg >= 0) { + return RegListGet(dst_reg); + } else { + return NullOpt; + } + } + + Optional VisitExpr_(const IfNode* op) final { + // Reserve a register for return + size_t merge_register = NewRegister(); + PrimExpr cond_value = this->VisitExpr(op->cond).value(); + + // turn ndarray cond value into scalar. 
+ cond_value = tir::Cast(DataType::Bool(), + tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), + {tir::StringImm("vm.builtin.read_if_cond"), cond_value})); + + tir::Stmt true_branch = WithNewScope([&]() { + PrimExpr true_value = this->VisitExpr(op->true_branch).value(); + this->EmitCallPacked("vm.builtin.copy", {true_value}, merge_register); + }); + tir::Stmt false_branch = WithNewScope([&]() { + PrimExpr false_value = this->VisitExpr(op->false_branch).value(); + this->EmitCallPacked("vm.builtin.copy", {false_value}, merge_register); + }); + this->EmitStmt(tir::IfThenElse(cond_value, true_branch, false_branch)); + return RegListGet(merge_register); + } + + Optional VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto it = this->var_map_.find(var); + ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined"; + return it->second; + } + + Optional VisitExpr_(const ConstantNode* op) final { + return ConstListGet(builder_->ConvertConstant(op->data).value()); + } + + Optional VisitExpr_(const ShapeExprNode* op) final { + std::vector shape; + for (PrimExpr e : op->values) { + if (auto* int_value = e.as()) { + shape.push_back(int_value->value); + } else { + LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values; + } + } + return ConstListGet(builder_->ConvertConstant(ShapeTuple(shape)).value()); + } + + Optional VisitExpr_(const PrimValueNode* op) final { return op->value; } + + Optional VisitExpr_(const StringImmNode* op) final { + return ConstListGet(builder_->ConvertConstant(op->value).value()); + } + + Optional VisitExpr_(const DataTypeImmNode* op) final { + return ConstListGet(builder_->ConvertConstant(op->value).value()); + } + + Optional VisitExpr_(const TupleNode* op) final { + Tuple tuple = GetRef(op); + Array args; + for (auto arg : tuple->fields) { + args.push_back(this->VisitExpr(arg).value()); + } + int32_t dst_register = NewRegister(); + this->EmitCallPacked("vm.builtin.make_tuple", args, dst_register); + return RegListGet(dst_register); + } + + Optional VisitExpr_(const TupleGetItemNode* op) final { + TupleGetItem expr = GetRef(op); + Array args = {this->VisitExpr(expr->tuple).value()}; + + args.push_back(ConstInt64(expr->index)); + + int64_t dst_register = NewRegister(); + this->EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register); + return RegListGet(dst_register); + } + + // Lookup the function and see if it matches + Optional LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) { + if (auto* ext_func = expr.as()) { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return ext_func->global_symbol; + } else if (auto* gvar_ptr = expr.as()) { + GlobalVar gvar = GetRef(gvar_ptr); + // Run a look up in the env to see if it maps to an extern func. + auto it = ctx_mod_->functions.find(gvar); + if (it != ctx_mod_->functions.end()) { + BaseFunc func = (*it).second; + if (auto* efunc = func.as()) { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return efunc->global_symbol; + } else if (func.as()) { + *kind = VMFuncInfo::FuncKind::kVMTIRFunc; + return gvar->name_hint; + } else if (func.as()) { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return gvar->name_hint; + } else { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return gvar->name_hint; + } + } + LOG(WARNING) << "Undefined global var " << gvar->name_hint; + // undefined global var, consider eliminate later. 
+ *kind = VMFuncInfo::FuncKind::kPackedFunc; + return gvar->name_hint; + } else { + return NullOpt; + } + } + // Lookup PrimFunc in the same module + // We can do direct PrimFunc call in such cases + Optional LookupPrimFunc(const String& name) { + if (!ctx_mod_->ContainGlobalVar(name)) return NullOpt; + + GlobalVar gvar = ctx_mod_->GetGlobalVar(name); + auto it = ctx_mod_->functions.find(gvar); + if (it != ctx_mod_->functions.end()) { + BaseFunc func = (*it).second; + if (auto* prim_func = func.as()) { + return GetRef(prim_func); + } + } + return NullOpt; + } + + Optional VisitExpr_(const GlobalVarNode* op) final { + VMFuncInfo::FuncKind kind; + auto symbol = LookupFunction(GetRef(op), &kind); + ICHECK(symbol.defined()); + builder_->DeclareFunction(symbol.value(), kind); + return FuncListGet(builder_->GetFunction(symbol.value()).value()); + } + + Optional VisitExpr_(const ExternFuncNode* op) final { + builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); + return FuncListGet(builder_->GetFunction(op->global_symbol).value()); + } + + void EmitAllocStorage(const Call& call_node, int64_t dst_reg) { + // Handle args of the call + Array args; + args.push_back(ctx_ptr_); + for (Expr arg : call_node->args) { + args.push_back(this->VisitExpr(arg).value()); + } + this->EmitCallPacked("vm.builtin.alloc_storage", args, dst_reg); + } + + void EmitAllocTensor(const Call& call_node, int64_t dst_reg) { + ICHECK_EQ(call_node->args.size(), 4); + Array args; + args.reserve(4); + for (Expr arg : call_node->args) { + args.push_back(this->VisitExpr(arg).value()); + } + this->EmitCallPacked("vm.builtin.alloc_tensor", args, dst_reg); + } + + void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) { + Array args; + // if context is required, pass as first argument. + args.push_back(ctx_ptr_); + auto* func = call_node->args[0].as(); + ICHECK(func) << "CallBuiltin comes with extern func"; + + auto tuple_arg = Downcast(call_node->args[1]); + + // Handle args of the call + for (Expr arg : tuple_arg->fields) { + args.push_back(this->VisitExpr(arg).value()); + } + + this->EmitCallPacked(func->global_symbol, args, dst_reg); + } + + void EmitNormalCall(const Call& call_node, int64_t dst_reg) { + Array args = VisitArray(call_node->args); + // A function can be a closure that comes from parent + // Do call closure to be safe. + VMFuncInfo::FuncKind kind; + auto symbol = LookupFunction(call_node->op, &kind); + + if (symbol.defined() && kind == VMFuncInfo::FuncKind::kPackedFunc) { + // primfunc in the same module. + // use cpacked to directly invoke without named based lookup + if (Optional prim_func = LookupPrimFunc(symbol.value())) { + this->EmitCallCPacked(prim_func.value(), args, dst_reg); + } else { + this->EmitCallPacked(symbol.value(), args, dst_reg); + } + } else { + // Default path, leverage function table and invoke as closure + Array all_args; + all_args.push_back(ctx_ptr_); + all_args.push_back(this->VisitExpr(call_node->op).value()); + for (auto arg : args) { + all_args.push_back(arg); + } + this->EmitCallPacked("vm.builtin.invoke_closure", all_args, dst_reg); + } + } + + template + tir::Stmt WithNewScope(const FLambda& callback) { + stmt_stack_.push_back({}); + callback(); + tir::Stmt stmt = tir::SeqStmt::Flatten(stmt_stack_.back()); + stmt_stack_.pop_back(); + return stmt; + } + + Array VisitArray(const Array& arr) { + Array ret; + for (size_t i = 0; i < arr.size(); ++i) { + ret.push_back(this->VisitExpr(arr[i]).value()); + } + return ret; + } + /*! 
\brief Internal ExecBuilder. */ + relax::ExecBuilder builder_; + /*! \brief List to ctx_ptr */ + tir::Var ctx_ptr_; + /*! \brief List to store temp object registers */ + tir::Var reg_anylist_handle_; + /*! \brief List to store closures */ + tir::Var func_anylist_handle_; + /*! \brief List to store constants */ + tir::Var const_anylist_handle_; + /*! + * \brief Total number of virtual registers allocated. + * \note The first two registers are reserved for special registers. + */ + int64_t registers_num_ = 0; + /*! \brief Stack to build up statements */ + std::vector> stmt_stack_; + /*! \brief Map from var to Expr. */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> var_map_; + /*! \brief the context module. */ + IRModule ctx_mod_; + /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */ + const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); + const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const Op& null_value_op_ = Op::Get("relax.null_value"); +}; + +/*! + * \brief Create the Relax VM executable from all relax.Function in mod. + * and add them to exec_builder. Create extra TIR functions. + * + * \param exec_builder Builder to collect executables. + * \param mod Input module. + * \return Extra TIR module created. + */ +IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) { + return CodeGenVMTIR::Run(exec_builder, mod); +} + +TVM_REGISTER_GLOBAL("relax.VMTIRCodeGen").set_body_typed(VMTIRCodeGen); + +} // namespace relax_vm +} // namespace relax +} // namespace tvm diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 54fd362387c5..1a12308413b2 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -77,7 +77,10 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& int ret_type_code = kTVMNullptr; int ret = (*faddr)(const_cast(args.values), const_cast(args.type_codes), args.num_args, &ret_value, &ret_type_code, nullptr); - ICHECK_EQ(ret, 0) << TVMGetLastError(); + // NOTE: important to keep the original error message. + if (ret != 0) { + LOG(FATAL) << TVMGetLastError(); + } if (ret_type_code != kTVMNullptr) { *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 21d2c6ebe0a5..10aa2688a846 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -905,8 +905,10 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, func_name, module_.get()); } - - nargs -= 1; + // NOTE: This is a bugfix to a previous coupled convention(in lower_tvm_builtin) + // The begin, end should correspond to the right location in cpacked excluding resource handle. + // TODO(tqchen): upstream the fix. 
+ // nargs -= 1; call_args.insert(call_args.end(), { builder_->CreateBitCast(arg_value, t_void_p_), arg_tcode.addr, diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index dc3208f484e3..cdba1346f0ae 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -316,6 +316,18 @@ TIR_DEFINE_BUILTIN_FUNC(start_profile_intrinsic) TIR_DEFINE_BUILTIN_FUNC(end_profile_intrinsic) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_BUILTIN_FUNC(anylist_getitem) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)); + +TIR_DEFINE_BUILTIN_FUNC(anylist_resetitem) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TGlobalSymbol", "TVMBackendAnyListResetItem"); + +TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_packed) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace builtin } // namespace tir } // namespace tvm diff --git a/src/tir/op/runtime.cc b/src/tir/op/runtime.cc new file mode 100644 index 000000000000..9ee6c67ec96b --- /dev/null +++ b/src/tir/op/runtime.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/op/runtime.cc + * \brief TIR ops for runtime functions. 
+ */ +#include +#include + +namespace tvm { +namespace tir { + +TVM_REGISTER_OP("tir.TVMBackendAnyListSetPackedArg") + .set_num_inputs(5) + .set_attr("TGlobalSymbol", "TVMBackendAnyListSetPackedArg") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TVM_REGISTER_OP("tir.TVMBackendAnyListMoveFromPackedReturn") + .set_num_inputs(3) + .set_attr("TGlobalSymbol", "TVMBackendAnyListMoveFromPackedReturn") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 082a54f9c73d..b0a87a3056b4 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -302,13 +302,21 @@ class BuiltinLower : public StmtExprMutator { return Stmt(n); } } + PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_call_packed())) { - return MakeCallPacked(op, /* use_string_lookup */ true); + return MakeCallPackedGeneric(op, 0, builtin::tvm_call_packed_lowered(), + /* use_string_lookup */ true); } else if (op->op.same_as(builtin::tvm_call_cpacked())) { - return MakeCallPacked(op, /* use_string_lookup */ false); + return MakeCallPackedGeneric(op, 0, builtin::tvm_call_cpacked_lowered(), + /* use_string_lookup */ false); } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { - return MakeCallTracePacked(op); + return MakeCallPackedGeneric(op, 0, builtin::tvm_call_trace_packed_lowered(), + /* use_string_lookup */ true); + } else if (op->op.same_as(builtin::anylist_setitem_call_packed())) { + return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_packed_lowered(), true); + } else if (op->op.same_as(builtin::anylist_setitem_call_cpacked())) { + return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_cpacked_lowered(), false); } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { return MakeShape(op); } else if (op->op.same_as(builtin::tvm_stack_make_array())) { @@ -418,8 +426,68 @@ class BuiltinLower : public StmtExprMutator { cast(DataType::Int(32), device_type_))); return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr); } - // call packed. 
- PrimExpr MakeCallPacked(const CallNode* op, bool use_string_lookup) { + + void SetPackedArg(PrimExpr arg, const Var& value_stack, const Buffer& tcode_stack, + size_t stack_offset, std::vector* prep_seq) { + auto* call_pattern = arg.as(); + if (call_pattern && call_pattern->op.same_as(builtin::anylist_getitem())) { + // call runtime function to set anylist + prep_seq->emplace_back( + Evaluate(Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListSetPackedArg"), + {call_pattern->args[0], call_pattern->args[1], value_stack, + tcode_stack->data, ConstInt32(stack_offset)}))); + } else { + DataType api_type = APIType(arg.dtype()); + if (arg.dtype() != api_type) { + arg = Cast(api_type, arg); + } + prep_seq->emplace_back( + TVMStructSet(value_stack, stack_offset, builtin::kTVMValueContent, arg)); + int arg_tcode = api_type.code(); + if (api_type.is_handle() && arg.as()) { + arg_tcode = kTVMStr; + } else if (IsArrayHandle(arg)) { + arg_tcode = kTVMDLTensorHandle; + } + // opaque handle need to set the kind properly + if (arg_tcode == kTVMOpaqueHandle) { + prep_seq->emplace_back(IfThenElse( + Call(DataType::Bool(), builtin::isnullptr(), {arg}), + BufferStore(tcode_stack, ConstInt32(kTVMNullptr), {ConstInt32(stack_offset)}), + BufferStore(tcode_stack, ConstInt32(arg_tcode), {ConstInt32(stack_offset)}))); + } else { + prep_seq->emplace_back( + BufferStore(tcode_stack, ConstInt32(arg_tcode), {ConstInt32(stack_offset)})); + } + } + } + + PrimExpr MakeAnyListSetItemCallPacked(const CallNode* op, const Op& lowered_op, + bool use_string_lookup) { + PrimExpr list_handle = op->args[0]; + PrimExpr list_index = op->args[1]; + + Call call = MakeCallPackedGeneric(op, 2, lowered_op, use_string_lookup); + PrimExpr value_stack = call->args[1]; + PrimExpr tcode_stack = call->args[2]; + // The stack offset of return value stack_end + PrimExpr ret_offset = call->args[4]; + auto& prep_seq = prep_seq_stack_.back(); + prep_seq.emplace_back(Evaluate(call)); + return Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListMoveFromPackedReturn"), + {list_handle, list_index, value_stack, tcode_stack, ret_offset}); + } + /*! + * \brief Generic tool to make low-level + * packed_call(other_args..., func_name, packed_arg0, packed_arg1...) + * + * \param op The call + * \param name_offset The beginning of function name and call packed section. + * \param lowered_packed_op The target lowered op. + * \param use_string_lookup Whether to lookup function by string. + */ + Call MakeCallPackedGeneric(const CallNode* op, size_t name_offset, const Op& lowered_packed_op, + bool use_string_lookup) { auto& scope = alloca_scope_.back(); auto& prep_seq = prep_seq_stack_.back(); @@ -427,34 +495,24 @@ class BuiltinLower : public StmtExprMutator { size_t restore_array_stack = scope.run_sizes.array_stack; size_t arg_stack_begin = scope.run_sizes.arg_stack; - size_t arg_count = op->args.size(); + size_t args_begin = name_offset + 1; + size_t args_end = op->args.size(); // cpacked expects a resource_handle parameter if (!use_string_lookup) { - arg_count--; + --args_end; } + size_t num_args = args_end - args_begin; - scope.run_sizes.arg_stack += arg_count; + // The extra one slot is for return value. 
+ scope.run_sizes.arg_stack += num_args + 1; // Specially handle the buffer packed intrinsic PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - for (size_t i = 1; i < arg_count; ++i) { - PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); - PrimExpr arg = op->args[i]; - DataType t = arg.dtype(); - DataType api_type = APIType(t); - if (t != api_type) { - arg = Cast(api_type, arg); - } - prep_seq.emplace_back(TVMStructSet(scope.stack_value, - static_cast(arg_stack_begin + i - 1), - builtin::kTVMValueContent, arg)); - int arg_tcode = api_type.code(); - if (api_type.is_handle() && arg.as()) { - arg_tcode = kTVMStr; - } - if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; - prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); + + for (size_t i = 0; i < num_args; ++i) { + this->SetPackedArg(op->args[args_begin + i], scope.stack_value, scope.stack_tcode, + arg_stack_begin + i, &prep_seq); } // Verify stack size matches earlier value. if (is_precheck_) { @@ -465,13 +523,12 @@ class BuiltinLower : public StmtExprMutator { scope.run_sizes.shape_stack = restore_shape_stack; scope.run_sizes.array_stack = restore_array_stack; scope.run_sizes.arg_stack = arg_stack_begin; - Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1)}; - + Array packed_args = {op->args[name_offset], scope.stack_value, + scope.stack_tcode->data, ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + num_args)}; // cpacked call resource_handle if (!use_string_lookup) { - PrimExpr last_arg = op->args[arg_count]; + PrimExpr last_arg = op->args[args_end]; const VarNode* var_node = last_arg.as(); if (var_node != nullptr) { tir::Var resource_handle = GetRef(var_node); @@ -480,57 +537,7 @@ class BuiltinLower : public StmtExprMutator { packed_args.push_back(last_arg); } } - - auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered() - : builtin::tvm_call_cpacked_lowered(); - return Call(op->dtype, builtin_call, packed_args); - } - - PrimExpr MakeCallTracePacked(const CallNode* op) { - ICHECK(!alloca_scope_.empty()); - auto& scope = alloca_scope_.back(); - auto& prep_seq = prep_seq_stack_.back(); - - int64_t restore_shape_stack = scope.run_sizes.shape_stack; - size_t restore_array_stack = scope.run_sizes.array_stack; - size_t arg_stack_begin = scope.run_sizes.arg_stack; - scope.run_sizes.arg_stack += op->args.size(); - size_t args_size = op->args.size(); - ICHECK_GT(args_size, 0); - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - for (size_t i = 1; i < op->args.size(); ++i) { - PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); - PrimExpr arg = op->args[i]; - DataType t = arg.dtype(); - DataType api_type = APIType(t); - if (t != api_type) { - arg = Cast(api_type, arg); - } - prep_seq.emplace_back(TVMStructSet(scope.stack_value, - static_cast(arg_stack_begin + i - 1), - builtin::kTVMValueContent, arg)); - int arg_tcode = api_type.code(); - ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; - prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); - } - // Verify stack size matches earlier value. 
- if (is_precheck_) { - scope.UpdateMax(); - } else { - scope.AssertMaxIsValid(); - } - scope.run_sizes.shape_stack = restore_shape_stack; - scope.run_sizes.array_stack = restore_array_stack; - // Update the top of the stack, so we can use more than one - // packed function's arguments with the one stack. - scope.run_sizes.arg_stack = arg_stack_begin + args_size - 1; - Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1), - // Pass traced value. - op->args[args_size - 1]}; - return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args); + return Call(op->dtype, lowered_packed_op, packed_args); } Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) { diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 0a881691accc..d57efd8b9992 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -30,7 +30,7 @@ from tvm.script import relax as R, tir as T from tvm.relax.testing.vm import check_saved_func -EXEC_MODE = ["bytecode"] +EXEC_MODE = ["bytecode", "compiled"] @pytest.mark.parametrize("exec_mode", EXEC_MODE) diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index 4b79ecf70fa1..600d2456174e 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -28,7 +28,7 @@ from tvm.script import relax as R from tvm.script import tir as T -EXEC_MODE = ["bytecode"] +EXEC_MODE = ["bytecode", "compiled"] def codegen(mod, target, exec_mode="bytecode"): diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py new file mode 100644 index 000000000000..6f3bced38581 --- /dev/null +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test the TIR codegen path of VM compiled mode. + +Restrictions: all shape lowered, explicit allocation. 
+""" +import tvm +import tvm.testing +from tvm import relax +from tvm.ir import assert_structural_equal +from tvm.script import relax as R +from tvm.script import tir as T + + +def get_tir_mod(mod): + builder = relax.ExecBuilder() + return relax.vm._vmcodegen(builder, mod, exec_mode="compiled") + + +def test_add(): + @tvm.script.ir_module + class Before: + @R.function + def foo(x: R.Tensor): + R.func_attr({"global_symbol": "foo"}) + z = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + T.func_attr({"global_symbol": "__vmtir__foo"}) + T.anylist_setitem_call_packed( + r, + T.int32(2), + "test.vm.add", + T.anylist_getitem(r, T.int32(0)), + T.anylist_getitem(r, T.int32(0)), + ) + T.anylist_setitem_call_packed( + r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(2)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_tir_call(): + @tvm.script.ir_module + class Before: + @T.prim_func + def shape_func(H: T.Buffer(T.int64(4), "int64")): + T.func_attr({"global_symbol": "shape_func"}) + # generated compute function + H[T.int64(0)] = H[T.int64(0)] + T.int64(1) + + @R.function + def foo(x: R.Tensor): + R.func_attr({"global_symbol": "foo"}) + _ = shape_func(x) + return x + + @tvm.script.ir_module + class Expected: + @T.prim_func + def shape_func(H: T.Buffer(T.int64(4), "int64")): + T.func_attr({"global_symbol": "shape_func"}) + # generated compute function + H[T.int64(0)] = H[T.int64(0)] + T.int64(1) + + @T.prim_func + def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + T.func_attr({"global_symbol": "__vmtir__foo"}) + T.call_cpacked( + "shape_func", T.anylist_getitem(r, T.int32(0)), T.reinterpret("handle", T.uint64(0)) + ) + T.anylist_setitem_call_packed( + r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(0)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_if_cond(): + @tvm.script.ir_module + class Before: + @R.function + def ife(cond: R.Tensor((), "bool"), x: R.Tensor) -> R.Tensor: + R.func_attr({"global_symbol": "ife"}) + if cond: + w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + else: + w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor)) + return w + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + T.func_attr({"global_symbol": "__vmtir__ife"}) + if T.cast( + T.tvm_call_packed("vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))), + "bool", + ): + T.anylist_setitem_call_packed( + r, + T.int32(4), + "test.vm.add", + T.anylist_getitem(r, T.int32(1)), + T.anylist_getitem(r, T.int32(1)), + ) + T.anylist_setitem_call_packed( + r, T.int32(3), "vm.builtin.copy", T.anylist_getitem(r, T.int32(4)) + ) + else: + T.anylist_setitem_call_packed( + r, + T.int32(5), + "test.vm.mul", + T.anylist_getitem(r, T.int32(1)), + T.anylist_getitem(r, T.int32(1)), + ) + T.anylist_setitem_call_packed( + r, T.int32(3), "vm.builtin.copy", T.anylist_getitem(r, T.int32(5)) + ) + T.anylist_setitem_call_packed( + r, T.int32(2), "vm.builtin.copy", T.anylist_getitem(r, T.int32(3)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_const(): + @tvm.script.ir_module + 
class Before: + @R.function + def main(x: R.Tensor): + R.func_attr({"global_symbol": "main"}) + y = R.const([1, 2]) + z = (y, R.const([3, 4]), x) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + # function attr dict + T.func_attr({"global_symbol": "__vmtir__main"}) + # body + T.anylist_setitem_call_packed( + r, + T.int32(2), + "vm.builtin.make_tuple", + T.anylist_getitem(c, T.int32(0)), + T.anylist_getitem(c, T.int32(1)), + T.anylist_getitem(r, T.int32(0)), + ) + T.anylist_setitem_call_packed( + r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(2)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_const_call(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + R.func_attr({"global_symbol": "main"}) + y = R.const([1, 2]) + z = R.call_packed("test.vm.add", x, y, sinfo_args=(R.Tensor)) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + # function attr dict + T.func_attr({"global_symbol": "__vmtir__main"}) + # body + T.anylist_setitem_call_packed( + r, + 2, + "test.vm.add", + T.anylist_getitem(r, 0), + T.anylist_getitem(c, 0), + ) + T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 2)) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +if __name__ == "__main__": + tvm.testing.main()
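
A minimal usage sketch (not part of the patch), based on the helpers already used in tests/python/relax/test_vm_codegen_tir.py above: it feeds a small Relax module through the compiled exec mode and prints the generated __vmtir__ PrimFunc, which drives the VM via the new T.anylist_getitem / T.anylist_setitem_call_packed intrinsics. The module body mirrors test_add; printing via IRModule.script() is an assumption about the inspection API, not something this patch adds.

# Hypothetical driver, assuming the same helpers as the tests above.
import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class MyMod:
    @R.function
    def main(x: R.Tensor):
        R.func_attr({"global_symbol": "main"})
        # Lowered to anylist_setitem_call_packed on the register list "r".
        y = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor))
        return y


builder = relax.ExecBuilder()
# exec_mode="compiled" takes the CodeGenVMTIR path added in this patch
# instead of emitting bytecode.
tir_mod = relax.vm._vmcodegen(builder, MyMod, exec_mode="compiled")
print(tir_mod.script())  # expect a __vmtir__main PrimFunc using T.anylist_* intrinsics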