diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 5650db6f909c..5219034acd72 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -1031,6 +1031,42 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*) for_each_dispatcher ::run(f, std::forward(args)...); } + +template +struct func_signature_helper { + using FType = void; +}; + +template +struct func_signature_helper { + using FType = R(Args...); +}; + +template +struct func_signature_helper { + using FType = R(Args...); +}; + +/*! + * \brief template class to get function signature of a function or functor. + * \tparam T The funtion/functor type. + */ +template +struct function_signature { + using FType = typename func_signature_helper::FType; +}; + +// handle case of function. +template +struct function_signature { + using FType = R(Args...); +}; + +// handle case of function ptr. +template +struct function_signature { + using FType = R(Args...); +}; } // namespace detail /* \brief argument settter to PackedFunc */ diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index a7e80418b9c4..6faa7b7c84d7 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -46,6 +46,7 @@ #include #include #include +#include namespace tvm { namespace runtime { @@ -66,28 +67,7 @@ class Registry { return set_body(PackedFunc(f)); } /*! - * \brief set the body of the function to be TypedPackedFunc. - * - * \code - * - * TVM_REGISTER_GLOBAL("addone") - * .set_body_typed([](int x) { return x + 1; }); - * - * \endcode - * - * \param f The body of the function. - * \tparam FType the signature of the function. - * \tparam FLambda The type of f. - */ - template - Registry& set_body_typed(FLambda f) { - return set_body(TypedPackedFunc(f).packed()); - } - - /*! - * \brief set the body of the function to the given function pointer. - * Note that this doesn't work with lambdas, you need to - * explicitly give a type for those. + * \brief set the body of the function to the given function. * Note that this will ignore default arg values and always require all arguments to be provided. * * \code @@ -99,17 +79,20 @@ class Registry { * TVM_REGISTER_GLOBAL("multiply") * .set_body_typed(multiply); // will have type int(int, int) * + * // will have type int(int, int) + * TVM_REGISTER_GLOBAL("sub") + * .set_body_typed([](int a, int b) -> int { return a - b; }); + * * \endcode * * \param f The function to forward to. - * \tparam R the return type of the function (inferred). - * \tparam Args the argument types of the function (inferred). + * \tparam FLambda The signature of the function. */ - template - Registry& set_body_typed(R (*f)(Args...)) { - return set_body(TypedPackedFunc(f)); + template + Registry& set_body_typed(FLambda f) { + using FType = typename detail::function_signature::FType; + return set_body(TypedPackedFunc(std::move(f)).packed()); } - /*! * \brief set the body of the function to be the passed method pointer. * Note that this will ignore default arg values and always require all arguments to be provided. @@ -132,10 +115,11 @@ class Registry { */ template Registry& set_body_method(R (T::*f)(Args...)) { - return set_body_typed([f](T target, Args... params) -> R { + auto fwrap =[f](T target, Args... params) -> R { // call method pointer return (target.*f)(params...); - }); + }; + return set_body(TypedPackedFunc(fwrap)); } /*! @@ -160,10 +144,11 @@ class Registry { */ template Registry& set_body_method(R (T::*f)(Args...) const) { - return set_body_typed([f](const T target, Args... params) -> R { + auto fwrap = [f](const T target, Args... params) -> R { // call method pointer return (target.*f)(params...); - }); + }; + return set_body(TypedPackedFunc(fwrap)); } /*! @@ -199,11 +184,12 @@ class Registry { template::value>::type> Registry& set_body_method(R (TNode::*f)(Args...)) { - return set_body_typed([f](TObjectRef ref, Args... params) { + auto fwrap = [f](TObjectRef ref, Args... params) { TNode* target = ref.operator->(); // call method pointer return (target->*f)(params...); - }); + }; + return set_body(TypedPackedFunc(fwrap)); } /*! @@ -239,11 +225,12 @@ class Registry { template::value>::type> Registry& set_body_method(R (TNode::*f)(Args...) const) { - return set_body_typed([f](TObjectRef ref, Args... params) { + auto fwrap = [f](TObjectRef ref, Args... params) { const TNode* target = ref.operator->(); // call method pointer return (target->*f)(params...); - }); + }; + return set_body(TypedPackedFunc(fwrap)); } /*! diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 5eef8dbe4178..a69fd5d436df 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -48,7 +48,7 @@ TVM_REGISTER_GLOBAL("arith.DetectClipBound") .set_body_typed(DetectClipBound); TVM_REGISTER_GLOBAL("arith.DeduceBound") -.set_body_typed, Map)>([]( +.set_body_typed([]( Expr v, Expr cond, const Map hint_map, const Map relax_map diff --git a/src/api/api_base.cc b/src/api/api_base.cc index 89dd4fc795a6..aa5fb784e968 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -45,10 +45,10 @@ TVM_REGISTER_GLOBAL("_raw_ptr") }); TVM_REGISTER_GLOBAL("_save_json") -.set_body_typed(SaveJSON); +.set_body_typed(SaveJSON); TVM_REGISTER_GLOBAL("_load_json") -.set_body_typed(LoadJSON); +.set_body_typed(LoadJSON); TVM_REGISTER_GLOBAL("_TVMSetStream") .set_body_typed(TVMSetStream); diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 2b7a36f9757b..2987b9ef39c1 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -32,7 +32,7 @@ namespace tvm { namespace ir { TVM_REGISTER_GLOBAL("_Var") -.set_body_typed([](std::string s, DataType t) { +.set_body_typed([](std::string s, DataType t) { return Variable::make(t, s); }); @@ -64,7 +64,7 @@ TVM_REGISTER_GLOBAL("make._range_by_min_extent") .set_body_typed(Range::make_by_min_extent); TVM_REGISTER_GLOBAL("make.For") -.set_body_typed([]( +.set_body_typed([]( VarExpr loop_var, Expr min, Expr extent, int for_type, int device_api, Stmt body) { return For::make(loop_var, @@ -99,7 +99,7 @@ TVM_REGISTER_GLOBAL("make.Realize") .set_body_typed(Realize::make); TVM_REGISTER_GLOBAL("make.Call") -.set_body_typed, int, FunctionRef, int)>([]( +.set_body_typed([]( DataType type, std::string name, Array args, int call_type, FunctionRef func, int value_index @@ -116,9 +116,9 @@ TVM_REGISTER_GLOBAL("make.CommReducer") .set_body_typed(CommReducerNode::make); // make from two arguments -#define REGISTER_MAKE(Node) \ +#define REGISTER_MAKE(Node) \ TVM_REGISTER_GLOBAL("make."#Node) \ - .set_body_typed(Node::make); \ + .set_body_typed(Node::make); \ REGISTER_MAKE(Reduce); REGISTER_MAKE(AttrStmt); @@ -168,32 +168,32 @@ TVM_REGISTER_GLOBAL("make.Block") // has default args TVM_REGISTER_GLOBAL("make.Allocate") - .set_body_typed, Expr, Stmt)>([]( + .set_body_typed([]( VarExpr buffer_var, DataType type, Array extents, Expr condition, Stmt body ){ return Allocate::make(buffer_var, type, extents, condition, body); }); // operator overloading, smarter than make -#define REGISTER_MAKE_BINARY_OP(Node, Func) \ +#define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_GLOBAL("make."#Node) \ - .set_body_typed([](Expr a, Expr b) { \ - return (Func(a, b)); \ - }) + .set_body_typed([](Expr a, Expr b) { \ + return (Func(a, b)); \ + }) #define REGISTER_MAKE_BIT_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("make."#Node) \ + TVM_REGISTER_GLOBAL("make."#Node) \ .set_body([](TVMArgs args, TVMRetValue *ret) { \ - bool lhs_is_int = args[0].type_code() == kDLInt; \ - bool rhs_is_int = args[1].type_code() == kDLInt; \ - if (lhs_is_int) { \ - *ret = (Func(args[0].operator int(), args[1].operator Expr())); \ - } else if (rhs_is_int) { \ - *ret = (Func(args[0].operator Expr(), args[1].operator int())); \ - } else { \ - *ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \ - } \ - }) + bool lhs_is_int = args[0].type_code() == kDLInt; \ + bool rhs_is_int = args[1].type_code() == kDLInt; \ + if (lhs_is_int) { \ + *ret = (Func(args[0].operator int(), args[1].operator Expr())); \ + } else if (rhs_is_int) { \ + *ret = (Func(args[0].operator Expr(), args[1].operator int())); \ + } else { \ + *ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \ + } \ + }) REGISTER_MAKE_BINARY_OP(_OpAdd, operator+); @@ -224,7 +224,7 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, operator>>); TVM_REGISTER_GLOBAL("make._OpIfThenElse") -.set_body_typed([] (Expr cond, Expr true_value, Expr false_value) { +.set_body_typed([] (Expr cond, Expr true_value, Expr false_value) { return if_then_else(cond, true_value, false_value); }); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index b8f7d0fef028..804d8f1f9c51 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -236,22 +236,22 @@ TVM_REGISTER_GLOBAL("_Layout") .set_body_typed(LayoutNode::make); TVM_REGISTER_GLOBAL("_LayoutIndexOf") -.set_body_typed([](Layout layout, std::string axis) { +.set_body_typed([](Layout layout, std::string axis) -> int { return layout.IndexOf(LayoutAxis::make(axis)); }); TVM_REGISTER_GLOBAL("_LayoutFactorOf") -.set_body_typed([](Layout layout, std::string axis) { +.set_body_typed([](Layout layout, std::string axis) -> int { return layout.FactorOf(LayoutAxis::make(axis)); }); TVM_REGISTER_GLOBAL("_LayoutNdim") -.set_body_typed([](Layout layout) { +.set_body_typed([](Layout layout) -> int { return layout.ndim(); }); TVM_REGISTER_GLOBAL("_LayoutGetItem") -.set_body_typed([](Layout layout, int idx) { +.set_body_typed([](Layout layout, int idx) -> std::string { const LayoutAxis& axis = layout[idx]; return axis.name(); }); @@ -284,14 +284,12 @@ TVM_REGISTER_GLOBAL("_TensorEqual") .set_body_method(&Tensor::operator==); TVM_REGISTER_GLOBAL("_TensorHash") -.set_body_typed([](Tensor tensor) { +.set_body_typed([](Tensor tensor) -> int64_t { return static_cast(std::hash()(tensor)); }); TVM_REGISTER_GLOBAL("_Placeholder") -.set_body_typed, DataType, std::string)>([]( - Array shape, DataType dtype, std::string name -) { +.set_body_typed([](Array shape, DataType dtype, std::string name) { return placeholder(shape, dtype, name); }); @@ -311,7 +309,7 @@ TVM_REGISTER_GLOBAL("_HybridOp") .set_body_typed(HybridOpNode::make); TVM_REGISTER_GLOBAL("_OpGetOutput") -.set_body_typed([](Operation op, int64_t output) { +.set_body_typed([](Operation op, int64_t output) { return op.output(static_cast(output)); }); @@ -322,9 +320,7 @@ TVM_REGISTER_GLOBAL("_OpInputTensors") .set_body_method(&OperationNode::InputTensors); TVM_REGISTER_GLOBAL("_IterVar") -.set_body_typed([]( - Range dom, Var var, int iter_type, std::string thread_tag -) { +.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) { return IterVarNode::make( dom, var, static_cast(iter_type), @@ -341,25 +337,21 @@ TVM_REGISTER_GLOBAL("_StageBind") .set_body_method(&Stage::bind); TVM_REGISTER_GLOBAL("_StageSplitByFactor") -.set_body_typed(Stage, IterVar, Expr)>([]( - Stage stage, IterVar parent, Expr factor -) { +.set_body_typed([](Stage stage, IterVar parent, Expr factor) { IterVar outer, inner; stage.split(parent, factor, &outer, &inner); return Array({outer, inner}); }); TVM_REGISTER_GLOBAL("_StageSplitByNParts") -.set_body_typed(Stage, IterVar, Expr)>([]( - Stage stage, IterVar parent, Expr nparts -) { +.set_body_typed([](Stage stage, IterVar parent, Expr nparts) { IterVar outer, inner; stage.split_by_nparts(parent, nparts, &outer, &inner); return Array({outer, inner}); }); TVM_REGISTER_GLOBAL("_StageFuse") -.set_body_typed)>([](Stage stage, Array axes) { +.set_body_typed([](Stage stage, Array axes) { IterVar fused; stage.fuse(axes, &fused); return fused; @@ -378,7 +370,7 @@ TVM_REGISTER_GLOBAL("_StageReorder") .set_body_method(&Stage::reorder); TVM_REGISTER_GLOBAL("_StageTile") -.set_body_typed(Stage, IterVar, IterVar, Expr, Expr)>([]( +.set_body_typed([]( Stage stage, IterVar x_parent, IterVar y_parent, Expr x_factor, Expr y_factor diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 7390e8bb61a3..16c2b1bf44d9 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -95,21 +95,21 @@ TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten") }); TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore") -.set_body_typed&)> +.set_body_typed ([](const Stmt& stmt, const Schedule& schedule, const Map& extern_buffer) { return RewriteForTensorCore(stmt, schedule, extern_buffer); }); TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual") -.set_body_typed( +.set_body_typed( [](const ObjectRef& lhs, const ObjectRef& rhs) { return AttrsEqual()(lhs, rhs); }); TVM_REGISTER_GLOBAL("ir_pass.AttrsHash") -.set_body_typed([](const ObjectRef &node) { +.set_body_typed([](const ObjectRef &node) -> int64_t { return AttrsHash()(node); - }); +}); TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") diff --git a/src/api/api_test.cc b/src/api/api_test.cc index d57a4e9348f8..de37111c9f92 100644 --- a/src/api/api_test.cc +++ b/src/api/api_test.cc @@ -106,7 +106,7 @@ void ErrorTest(int x, int y) { } TVM_REGISTER_GLOBAL("_ErrorTest") -.set_body_typed(ErrorTest); +.set_body_typed(ErrorTest); // internal function used for debug and testing purposes TVM_REGISTER_GLOBAL("_ndarray_use_count") diff --git a/src/ir/type.cc b/src/ir/type.cc index ef5f75b86a2c..bfb44e817a95 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -36,7 +36,7 @@ TypeVar TypeVarNode::make(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_GLOBAL("relay._make.TypeVar") -.set_body_typed([](std::string name, int kind) { +.set_body_typed([](std::string name, int kind) { return TypeVarNode::make(name, static_cast(kind)); }); @@ -57,7 +57,7 @@ GlobalTypeVar GlobalTypeVarNode::make(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar") -.set_body_typed([](std::string name, int kind) { +.set_body_typed([](std::string name, int kind) { return GlobalTypeVarNode::make(name, static_cast(kind)); }); diff --git a/src/node/env_func.cc b/src/node/env_func.cc index 52bb61d7517c..e963576fab93 100644 --- a/src/node/env_func.cc +++ b/src/node/env_func.cc @@ -51,7 +51,7 @@ EnvFunc EnvFunc::Get(const std::string& name) { } TVM_REGISTER_GLOBAL("_EnvFuncGet") -.set_body_typed(EnvFunc::Get); +.set_body_typed(EnvFunc::Get); TVM_REGISTER_GLOBAL("_EnvFuncCall") .set_body([](TVMArgs args, TVMRetValue* rv) { @@ -63,7 +63,7 @@ TVM_REGISTER_GLOBAL("_EnvFuncCall") }); TVM_REGISTER_GLOBAL("_EnvFuncGetPackedFunc") -.set_body_typed([](const EnvFunc&n) { +.set_body_typed([](const EnvFunc&n) { return n->func; }); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ae993e90ad08..96cd5a15ff31 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -816,43 +816,43 @@ const CompileEngine& CompileEngine::Global() { } TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") -.set_body_typed(CCacheKeyNode::make); +.set_body_typed(CCacheKeyNode::make); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal") -.set_body_typed([]() { +.set_body_typed([]() { return CompileEngine::Global(); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear") -.set_body_typed([](CompileEngine self) { +.set_body_typed([](CompileEngine self) { self->Clear(); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") -.set_body_typed( +.set_body_typed( [](CompileEngine self, CCacheKey key) { return self->Lower(key); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") -.set_body_typed( +.set_body_typed( [](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") -.set_body_typed([](CompileEngine self) { +.set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") -.set_body_typed( +.set_body_typed( [](CompileEngine self, CCacheKey key) { return self->JIT(key); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems") -.set_body_typed(CompileEngine)>( +.set_body_typed( [](CompileEngine self){ return static_cast(self.operator->())->ListItems(); }); diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 71aa5cf6634a..ecb2d59aee6f 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -390,7 +390,7 @@ Map > GraphPlanMemory(const Function& func) { } TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory") -.set_body_typed >(const Function&)>(GraphPlanMemory); +.set_body_typed(GraphPlanMemory); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 7fe39db01a04..248f06bbe508 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -595,23 +595,23 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { // TODO(@jroesch): move to correct namespace? TVM_REGISTER_GLOBAL("relay._make._alpha_equal") -.set_body_typed([](ObjectRef a, ObjectRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(false, false).Equal(a, b); }); TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal") -.set_body_typed([](ObjectRef a, ObjectRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal"; }); TVM_REGISTER_GLOBAL("relay._make._graph_equal") -.set_body_typed([](ObjectRef a, ObjectRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(true, false).Equal(a, b); }); TVM_REGISTER_GLOBAL("relay._make._assert_graph_equal") -.set_body_typed([](ObjectRef a, ObjectRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal"; }); diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 176ee0829c96..4bac1fd1f626 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -35,7 +35,7 @@ using namespace tvm::runtime; TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_GLOBAL("relay._base.set_span") -.set_body_typed([](ObjectRef node_ref, Span sp) { +.set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { CHECK(rn); rn->span = sp; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 11689b079c67..27381b88dc7b 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -167,7 +167,7 @@ Function FunctionNode::SetParams(const tvm::Map& parameters) cons } TVM_REGISTER_GLOBAL("relay._expr.FunctionSetParams") -.set_body_typed&)>( +.set_body_typed( [](const Function& func, const tvm::Map& parameters) { return func->SetParams(parameters); }); @@ -178,7 +178,7 @@ tvm::Map FunctionNode::GetParams() const { } TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams") -.set_body_typed(const Function&)>([](const Function& func) { +.set_body_typed([](const Function& func) { return func->GetParams(); }); @@ -367,12 +367,12 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize") -.set_body_typed([](TempExpr temp) { +.set_body_typed([](TempExpr temp) { return temp->Realize(); }); TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr") -.set_body_typed( +.set_body_typed( [](Function func, std::string name, ObjectRef ref) { return FunctionSetAttr(func, name, ref); }); diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index d6e4d4192d49..0da763ab4083 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -348,7 +348,7 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { } TVM_REGISTER_GLOBAL("relay._analysis.post_order_visit") -.set_body_typed([](Expr expr, PackedFunc f) { +.set_body_typed([](Expr expr, PackedFunc f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); }); diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index b9406669ddfc..cf1e280d6b9e 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -424,12 +424,12 @@ size_t StructuralHash::operator()(const Expr& expr) const { } TVM_REGISTER_GLOBAL("relay._analysis._expr_hash") -.set_body_typed([](ObjectRef ref) { +.set_body_typed([](ObjectRef ref) { return static_cast(RelayHashHandler().Hash(ref)); }); TVM_REGISTER_GLOBAL("relay._analysis._type_hash") -.set_body_typed([](Type type) { +.set_body_typed([](Type type) { return static_cast(RelayHashHandler().TypeHash(type)); }); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 4e57258981c6..e663168ff59f 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -318,7 +318,7 @@ Module FromText(const std::string& source, const std::string& source_name) { TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_GLOBAL("relay._make.Module") -.set_body_typed, tvm::Map)>( +.set_body_typed( [](tvm::Map funcs, tvm::Map types) { return ModuleNode::make(funcs, types, {}); }); @@ -365,52 +365,49 @@ TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVar") .set_body_method(&ModuleNode::GetGlobalTypeVar); TVM_REGISTER_GLOBAL("relay._module.Module_Lookup") -.set_body_typed([](Module mod, GlobalVar var) { +.set_body_typed([](Module mod, GlobalVar var) { return mod->Lookup(var); }); TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str") -.set_body_typed([](Module mod, std::string var) { +.set_body_typed([](Module mod, std::string var) { return mod->Lookup(var); }); TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef") -.set_body_typed([](Module mod, GlobalTypeVar var) { +.set_body_typed([](Module mod, GlobalTypeVar var) { return mod->LookupDef(var); }); TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str") -.set_body_typed([](Module mod, std::string var) { +.set_body_typed([](Module mod, std::string var) { return mod->LookupDef(var); }); TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag") -.set_body_typed([](Module mod, int32_t tag) { +.set_body_typed([](Module mod, int32_t tag) { return mod->LookupTag(tag); }); TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr") -.set_body_typed< - Module(Expr, - tvm::Map, - tvm::Map)>([](Expr e, - tvm::Map funcs, - tvm::Map type_defs) { - return ModuleNode::FromExpr(e, funcs, type_defs); - }); +.set_body_typed([](Expr e, + tvm::Map funcs, + tvm::Map type_defs) { + return ModuleNode::FromExpr(e, funcs, type_defs); +}); TVM_REGISTER_GLOBAL("relay._module.Module_Update") -.set_body_typed([](Module mod, Module from) { +.set_body_typed([](Module mod, Module from) { mod->Update(from); }); TVM_REGISTER_GLOBAL("relay._module.Module_Import") -.set_body_typed([](Module mod, std::string path) { +.set_body_typed([](Module mod, std::string path) { mod->Import(path); }); TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd") -.set_body_typed([](Module mod, std::string path) { +.set_body_typed([](Module mod, std::string path) { mod->ImportFromStd(path); });; diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 4bef724957de..9bea870eda93 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -136,7 +136,7 @@ void OpRegistry::UpdateAttr(const std::string& key, // Frontend APIs TVM_REGISTER_GLOBAL("relay.op._ListOpNames") -.set_body_typed()>([]() { +.set_body_typed([]() { Array ret; for (const std::string& name : dmlc::Registry::ListAllNames()) { @@ -145,7 +145,7 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames") return ret; }); -TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get); +TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get); TVM_REGISTER_GLOBAL("relay.op._OpGetAttr") .set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 612d586b9dd5..362bbf0bb659 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -991,9 +991,7 @@ std::string AsText(const ObjectRef& node, } TVM_REGISTER_GLOBAL("relay._expr.AsText") -.set_body_typed)>(AsText); +.set_body_typed(AsText); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index aa9d37649586..a7f275d5361b 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -91,7 +91,7 @@ IncompleteType IncompleteTypeNode::make(Kind kind) { TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); TVM_REGISTER_GLOBAL("relay._make.IncompleteType") -.set_body_typed([](int kind) { +.set_body_typed([](int kind) { return IncompleteTypeNode::make(static_cast(kind)); }); @@ -161,8 +161,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TVM_REGISTER_GLOBAL("relay._make.Any") -.set_body_typed([]() { return Any::make(); }); - +.set_body_typed([]() { return Any::make(); }); } // namespace relay } // namespace tvm diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 253af5bd3a5b..efcb383d5e9d 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -40,7 +40,7 @@ namespace relay { TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") -.set_body_typed([](Expr data, int device_type) { +.set_body_typed([](Expr data, int device_type) { auto attrs = make_object(); attrs->device_type = device_type; static const Op& op = Op::Get("on_device"); @@ -63,7 +63,7 @@ Expr StopFusion(Expr data) { } TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion") -.set_body_typed([](Expr data) { +.set_body_typed([](Expr data) { return StopFusion(data); }); @@ -145,7 +145,7 @@ Mark the end of bitpacking. }); TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint") -.set_body_typed([](Expr data) { +.set_body_typed([](Expr data) { static const Op& op = Op::Get("annotation.checkpoint"); return CallNode::make(op, {data}, Attrs{}, {}); }); diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 9c3f6afa98ea..463b76f7046d 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -42,7 +42,7 @@ namespace relay { TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); TVM_REGISTER_GLOBAL("relay.op._make.device_copy") -.set_body_typed([](Expr data, int src_dev_type, +.set_body_typed([](Expr data, int src_dev_type, int dst_dev_type) { auto attrs = make_object(); attrs->src_dev_type = src_dev_type; diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index af7291de4058..42f016d64b8f 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -42,7 +42,7 @@ TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs); // We should consider a better solution, i.e the type relation // being able to see the arguments as well? TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage") - .set_body_typed([](Expr size, Expr alignment, DataType dtype) { + .set_body_typed([](Expr size, Expr alignment, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("memory.alloc_storage"); @@ -88,7 +88,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") }); TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor") - .set_body_typed assert_shape)>( + .set_body_typed( [](Expr storage, tvm::relay::Expr shape, DataType dtype, Array assert_shape) { auto attrs = make_object(); attrs->dtype = dtype; @@ -209,7 +209,7 @@ bool InvokeTVMOPRel(const Array& types, int num_inputs, const Attrs& attrs } TVM_REGISTER_GLOBAL("relay.op.memory._make.invoke_tvm_op") - .set_body_typed( + .set_body_typed( [](Expr func, Expr inputs, Expr outputs) { return CallNode::make(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs()); }); @@ -257,7 +257,7 @@ RELAY_REGISTER_OP("memory.kill") }); TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_func") - .set_body_typed)>( + .set_body_typed( [](Expr func, Expr inputs, Expr outputs, Array is_input) { static const Op& op = Op::Get("memory.shape_func"); auto attrs = make_object(); diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 50cde43c3c79..bab59f7b246b 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -326,7 +326,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax") -.set_body_typed([](Expr data, int axis) { +.set_body_typed([](Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.softmax"); @@ -361,7 +361,7 @@ RELAY_REGISTER_OP("nn.softmax") // relay.nn.log_softmax TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax") -.set_body_typed([](Expr data, int axis) { +.set_body_typed([](Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.log_softmax"); @@ -470,7 +470,7 @@ Example:: // relu TVM_REGISTER_GLOBAL("relay.op.nn._make.relu") -.set_body_typed([](Expr data) { +.set_body_typed([](Expr data) { static const Op& op = Op::Get("nn.relu"); return CallNode::make(op, {data}, Attrs(), {}); }); diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 88d43062792c..529435d24494 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -214,13 +214,12 @@ Array Pool2DCompute(const Attrs& attrs, } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d") -.set_body_typed, Array, Array, - std::string, bool)>([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { +.set_body_typed([](Expr data, + Array pool_size, + Array strides, + Array padding, + std::string layout, + bool ceil_mode) { return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, "nn.max_pool2d"); }); @@ -258,14 +257,13 @@ RELAY_REGISTER_OP("nn.max_pool2d") // AvgPool2D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d") -.set_body_typed, Array, Array, - std::string, bool, bool)>([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { +.set_body_typed([](Expr data, + Array pool_size, + Array strides, + Array padding, + std::string layout, + bool ceil_mode, + bool count_include_pad) { return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad, "nn.avg_pool2d"); }); @@ -868,13 +866,12 @@ Array Pool3DCompute(const Attrs& attrs, } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d") -.set_body_typed, Array, Array, - std::string, bool)>([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { +.set_body_typed([](Expr data, + Array pool_size, + Array strides, + Array padding, + std::string layout, + bool ceil_mode) { return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, "nn.max_pool3d"); }); @@ -912,14 +909,13 @@ RELAY_REGISTER_OP("nn.max_pool3d") // AvgPool3D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d") -.set_body_typed, Array, Array, - std::string, bool, bool)>([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { +.set_body_typed([](Expr data, + Array pool_size, + Array strides, + Array padding, + std::string layout, + bool ceil_mode, + bool count_include_pad) { return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad, "nn.avg_pool3d"); }); diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index fc8978b94b75..d9846cf443f7 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -48,19 +48,19 @@ namespace relay { * \param OpName the name of registry. */ #define RELAY_REGISTER_UNARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr data) { \ - static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {data}, Attrs(), {}); \ - }); \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ + .set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {data}, Attrs(), {}); \ + }); \ RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(1) \ - .add_argument("data", "Tensor", "The input tensor.") \ - .add_type_rel("Identity", IdentityRel) \ - .set_attr("TOpPattern", kElemWise) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - ElemwiseArbitraryLayout) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") \ + .add_type_rel("Identity", IdentityRel) \ + .set_attr("TOpPattern", kElemWise) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", \ + ElemwiseArbitraryLayout) \ /*! Quick helper macro @@ -73,38 +73,38 @@ namespace relay { * * \param OpName the name of registry. */ -#define RELAY_REGISTER_BINARY_OP(OpName) \ +#define RELAY_REGISTER_BINARY_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(2) \ - .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .add_type_rel("Broadcast", BroadcastRel) \ - .set_attr("TOpPattern", kBroadcast) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - BinaryBroadcastLayout) + .set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .add_type_rel("Broadcast", BroadcastRel) \ + .set_attr("TOpPattern", kBroadcast) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", \ + BinaryBroadcastLayout) // Comparisons -#define RELAY_REGISTER_CMP_OP(OpName) \ +#define RELAY_REGISTER_CMP_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(2) \ - .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .add_type_rel("BroadcastComp", BroadcastCompRel) \ - .set_attr("TOpPattern", kBroadcast) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - BinaryBroadcastLayout) + .set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .add_type_rel("BroadcastComp", BroadcastCompRel) \ + .set_attr("TOpPattern", kBroadcast) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", \ + BinaryBroadcastLayout) /*! \brief A helper class for matching and rewriting operators. */ diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 07f1f56d20e3..ae8e62c49980 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -303,7 +303,7 @@ bool ReduceRel(const Array& types, #define RELAY_REGISTER_REDUCE_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed, bool, bool)>([]( \ + .set_body_typed([]( \ Expr data, \ Array axis, \ bool keepdims, \ diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 1d56a0fe3f77..ca40d2a71915 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -858,7 +858,7 @@ bool ArgWhereRel(const Array& types, } TVM_REGISTER_GLOBAL("relay.op._make.argwhere") -.set_body_typed([](Expr data) { +.set_body_typed([](Expr data) { static const Op& op = Op::Get("argwhere"); auto attrs = make_object(); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index cc8419c36733..5189f7d97ec3 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -158,7 +158,7 @@ RELAY_REGISTER_UNARY_OP("copy") TVM_REGISTER_NODE_TYPE(ClipAttrs); TVM_REGISTER_GLOBAL("relay.op._make.clip") -.set_body_typed([](Expr a, double a_min, double a_max) { +.set_body_typed([](Expr a, double a_min, double a_max) { auto attrs = make_object(); attrs->a_min = a_min; attrs->a_max = a_max; @@ -301,7 +301,7 @@ Array ShapeOfCompute(const Attrs& attrs, } TVM_REGISTER_GLOBAL("relay.op._make.shape_of") -.set_body_typed([](Expr data, DataType dtype) { +.set_body_typed([](Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("shape_of"); @@ -352,7 +352,7 @@ Array NdarraySizeCompute(const Attrs& attrs, } TVM_REGISTER_GLOBAL("relay.op.contrib._make.ndarray_size") -.set_body_typed([](Expr data, DataType dtype) { +.set_body_typed([](Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("contrib.ndarray_size"); diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc index d9e8d87ed15d..2392b18e768f 100644 --- a/src/relay/pass/match_exhaustion.cc +++ b/src/relay/pass/match_exhaustion.cc @@ -327,7 +327,7 @@ Array UnmatchedCases(const Match& match, const Module& mod) { // expose for testing only TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases") -.set_body_typed(const Match&, const Module&)>( +.set_body_typed( [](const Match& match, const Module& mod_ref) { Module call_mod = mod_ref; if (!call_mod.defined()) { diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index c995994757a7..b4d6926ef0da 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -67,7 +67,7 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize") -.set_body_typed( +.set_body_typed( [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign, std::string rounding) { auto attrs = make_object(); diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index a2944a9d4abd..78101bcf045a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -79,7 +79,7 @@ bool TupleGetItemRel(const Array& types, TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem") -.set_body_typed&, int, const Attrs&, const TypeReporter&)>( +.set_body_typed( TupleGetItemRel); struct ResolvedTypeInfo { @@ -840,7 +840,7 @@ Pass InferType() { } TVM_REGISTER_GLOBAL("relay._transform.InferType") -.set_body_typed([]() { +.set_body_typed([]() { return InferType(); }); diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index 2a330098e074..444ca60327a1 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -63,19 +63,18 @@ static inline bool QnnBroadcastRel(const Array& types, int num_inputs, con * * \param OpName the name of registry. */ -#define QNN_REGISTER_BINARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ - .set_body_typed( \ - [](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ - Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ - static const Op& op = Op::Get("qnn." OpName); \ - return CallNode::make(op, {lhs, rhs, \ - lhs_scale, lhs_zero_point, \ - rhs_scale, rhs_zero_point, \ - output_scale, output_zero_point}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP("qnn." OpName) \ - .set_num_inputs(8) \ +#define QNN_REGISTER_BINARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ + .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ + Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ + static const Op& op = Op::Get("qnn." OpName); \ + return CallNode::make(op, {lhs, rhs, \ + lhs_scale, lhs_zero_point, \ + rhs_scale, rhs_zero_point, \ + output_scale, output_zero_point}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP("qnn." OpName) \ + .set_num_inputs(8) \ .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \ diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index f02fadb53ed9..3714425a3323 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -795,7 +795,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields") }); TVM_REGISTER_GLOBAL("relay._vm.Load_Executable") -.set_body_typed([]( +.set_body_typed([]( std::string code, runtime::Module lib) { return Executable::Load(code, lib); diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index a9b9b83ee250..fba8e10d4c35 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -179,6 +179,24 @@ TEST(TypedPackedFunc, HighOrder) { CHECK_EQ(f1(3), 4); } +TEST(TypedPackedFunc, Deduce) { + using namespace tvm::runtime; + using tvm::runtime::detail::function_signature; + + TypedPackedFunc x; + auto f = [](int x) -> int { + return x + 1; + }; + std::function y; + + static_assert(std::is_same::FType, + int(float)>::value, "invariant1"); + static_assert(std::is_same::FType, + int(int)>::value, "invariant2"); + static_assert(std::is_same::FType, + void(float)>::value, "invariant3"); +} + TEST(PackedFunc, ObjectConversion) { using namespace tvm; diff --git a/web/web_runtime.cc b/web/web_runtime.cc index d5b40889472c..701ded76288e 100644 --- a/web/web_runtime.cc +++ b/web/web_runtime.cc @@ -60,13 +60,13 @@ struct RPCEnv { }; TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") -.set_body_typed([](std::string path) { +.set_body_typed([](std::string path) { static RPCEnv env; return env.GetPath(path); }); TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") -.set_body_typed([](std::string path) { +.set_body_typed([](std::string path) { std::string file_name = "/rpc/" + path; LOG(INFO) << "Load module from " << file_name << " ..."; return Module::LoadFromFile(file_name, "");