Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,42 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
for_each_dispatcher<sizeof...(Args) == 0, 0, F>
::run(f, std::forward<Args>(args)...);
}

template<typename T>
struct func_signature_helper {
using FType = void;
};

template<typename T, typename R, typename ...Args>
struct func_signature_helper<R (T::*)(Args...)> {
using FType = R(Args...);
};

template<typename T, typename R, typename ...Args>
struct func_signature_helper<R (T::*)(Args...) const> {
using FType = R(Args...);
};

/*!
* \brief template class to get function signature of a function or functor.
* \tparam T The funtion/functor type.
*/
template<typename T>
struct function_signature {
using FType = typename func_signature_helper<decltype(&T::operator())>::FType;
};

// handle case of function.
template<typename R, typename ...Args>
struct function_signature<R(Args...)> {
using FType = R(Args...);
};

// handle case of function ptr.
template<typename R, typename ...Args>
struct function_signature<R (*)(Args...)> {
using FType = R(Args...);
};
} // namespace detail

/* \brief argument settter to PackedFunc */
Expand Down
59 changes: 23 additions & 36 deletions include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include <tvm/runtime/packed_func.h>
#include <string>
#include <vector>
#include <utility>

namespace tvm {
namespace runtime {
Expand All @@ -66,28 +67,7 @@ class Registry {
return set_body(PackedFunc(f));
}
/*!
* \brief set the body of the function to be TypedPackedFunc.
*
* \code
*
* TVM_REGISTER_GLOBAL("addone")
* .set_body_typed<int(int)>([](int x) { return x + 1; });
*
* \endcode
*
* \param f The body of the function.
* \tparam FType the signature of the function.
* \tparam FLambda The type of f.
*/
template<typename FType, typename FLambda>
Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<FType>(f).packed());
}

/*!
* \brief set the body of the function to the given function pointer.
* Note that this doesn't work with lambdas, you need to
* explicitly give a type for those.
* \brief set the body of the function to the given function.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
Expand All @@ -99,17 +79,20 @@ class Registry {
* TVM_REGISTER_GLOBAL("multiply")
* .set_body_typed(multiply); // will have type int(int, int)
*
* // will have type int(int, int)
* TVM_REGISTER_GLOBAL("sub")
* .set_body_typed([](int a, int b) -> int { return a - b; });
*
* \endcode
*
* \param f The function to forward to.
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
* \tparam FLambda The signature of the function.
*/
template<typename R, typename ...Args>
Registry& set_body_typed(R (*f)(Args...)) {
return set_body(TypedPackedFunc<R(Args...)>(f));
template<typename FLambda>
Registry& set_body_typed(FLambda f) {
using FType = typename detail::function_signature<FLambda>::FType;
return set_body(TypedPackedFunc<FType>(std::move(f)).packed());
}

/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided.
Expand All @@ -132,10 +115,11 @@ class Registry {
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...)) {
return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
auto fwrap =[f](T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
};
return set_body(TypedPackedFunc<R(T, Args...)>(fwrap));
}

/*!
Expand All @@ -160,10 +144,11 @@ class Registry {
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...) const) {
return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
auto fwrap = [f](const T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
};
return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap));
}

/*!
Expand Down Expand Up @@ -199,11 +184,12 @@ class Registry {
template<typename TObjectRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) {
return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
auto fwrap = [f](TObjectRef ref, Args... params) {
TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
};
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap));
}

/*!
Expand Down Expand Up @@ -239,11 +225,12 @@ class Registry {
template<typename TObjectRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) {
return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
auto fwrap = [f](TObjectRef ref, Args... params) {
const TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
};
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap));
}

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ TVM_REGISTER_GLOBAL("arith.DetectClipBound")
.set_body_typed(DetectClipBound);

TVM_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
.set_body_typed([](
Expr v, Expr cond,
const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map
Expand Down
4 changes: 2 additions & 2 deletions src/api/api_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ TVM_REGISTER_GLOBAL("_raw_ptr")
});

TVM_REGISTER_GLOBAL("_save_json")
.set_body_typed<std::string(ObjectRef)>(SaveJSON);
.set_body_typed(SaveJSON);

TVM_REGISTER_GLOBAL("_load_json")
.set_body_typed<ObjectRef(std::string)>(LoadJSON);
.set_body_typed(LoadJSON);

TVM_REGISTER_GLOBAL("_TVMSetStream")
.set_body_typed(TVMSetStream);
Expand Down
44 changes: 22 additions & 22 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace tvm {
namespace ir {

TVM_REGISTER_GLOBAL("_Var")
.set_body_typed<VarExpr(std::string, DataType)>([](std::string s, DataType t) {
.set_body_typed([](std::string s, DataType t) {
return Variable::make(t, s);
});

Expand Down Expand Up @@ -64,7 +64,7 @@ TVM_REGISTER_GLOBAL("make._range_by_min_extent")
.set_body_typed(Range::make_by_min_extent);

TVM_REGISTER_GLOBAL("make.For")
.set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
.set_body_typed([](
VarExpr loop_var, Expr min, Expr extent,
int for_type, int device_api, Stmt body) {
return For::make(loop_var,
Expand Down Expand Up @@ -99,7 +99,7 @@ TVM_REGISTER_GLOBAL("make.Realize")
.set_body_typed(Realize::make);

TVM_REGISTER_GLOBAL("make.Call")
.set_body_typed<Expr(DataType, std::string, Array<Expr>, int, FunctionRef, int)>([](
.set_body_typed([](
DataType type, std::string name,
Array<Expr> args, int call_type,
FunctionRef func, int value_index
Expand All @@ -116,9 +116,9 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
.set_body_typed(CommReducerNode::make);

// make from two arguments
#define REGISTER_MAKE(Node) \
#define REGISTER_MAKE(Node) \
TVM_REGISTER_GLOBAL("make."#Node) \
.set_body_typed(Node::make); \
.set_body_typed(Node::make); \

REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);
Expand Down Expand Up @@ -168,32 +168,32 @@ TVM_REGISTER_GLOBAL("make.Block")

// has default args
TVM_REGISTER_GLOBAL("make.Allocate")
.set_body_typed<Stmt(VarExpr, DataType, Array<Expr>, Expr, Stmt)>([](
.set_body_typed([](
VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body
){
return Allocate::make(buffer_var, type, extents, condition, body);
});

// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \
.set_body_typed<Expr(Expr, Expr)>([](Expr a, Expr b) { \
return (Func(a, b)); \
})
.set_body_typed([](Expr a, Expr b) { \
return (Func(a, b)); \
})

#define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \
TVM_REGISTER_GLOBAL("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \
*ret = (Func(args[0].operator int(), args[1].operator Expr())); \
} else if (rhs_is_int) { \
*ret = (Func(args[0].operator Expr(), args[1].operator int())); \
} else { \
*ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \
} \
})
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \
*ret = (Func(args[0].operator int(), args[1].operator Expr())); \
} else if (rhs_is_int) { \
*ret = (Func(args[0].operator Expr(), args[1].operator int())); \
} else { \
*ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \
} \
})


REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
Expand Down Expand Up @@ -224,7 +224,7 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
TVM_REGISTER_GLOBAL("make._OpIfThenElse")
.set_body_typed<Expr(Expr, Expr, Expr)>([] (Expr cond, Expr true_value, Expr false_value) {
.set_body_typed([] (Expr cond, Expr true_value, Expr false_value) {
return if_then_else(cond, true_value, false_value);
});

Expand Down
32 changes: 12 additions & 20 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,22 +236,22 @@ TVM_REGISTER_GLOBAL("_Layout")
.set_body_typed(LayoutNode::make);

TVM_REGISTER_GLOBAL("_LayoutIndexOf")
.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::make(axis));
});

TVM_REGISTER_GLOBAL("_LayoutFactorOf")
.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.FactorOf(LayoutAxis::make(axis));
});

TVM_REGISTER_GLOBAL("_LayoutNdim")
.set_body_typed<int(Layout)>([](Layout layout) {
.set_body_typed([](Layout layout) -> int {
return layout.ndim();
});

TVM_REGISTER_GLOBAL("_LayoutGetItem")
.set_body_typed<std::string(Layout, int)>([](Layout layout, int idx) {
.set_body_typed([](Layout layout, int idx) -> std::string {
const LayoutAxis& axis = layout[idx];
return axis.name();
});
Expand Down Expand Up @@ -284,14 +284,12 @@ TVM_REGISTER_GLOBAL("_TensorEqual")
.set_body_method(&Tensor::operator==);

TVM_REGISTER_GLOBAL("_TensorHash")
.set_body_typed<int64_t(Tensor)>([](Tensor tensor) {
.set_body_typed([](Tensor tensor) -> int64_t {
return static_cast<int64_t>(std::hash<Tensor>()(tensor));
});

TVM_REGISTER_GLOBAL("_Placeholder")
.set_body_typed<Tensor(Array<Expr>, DataType, std::string)>([](
Array<Expr> shape, DataType dtype, std::string name
) {
.set_body_typed([](Array<Expr> shape, DataType dtype, std::string name) {
return placeholder(shape, dtype, name);
});

Expand All @@ -311,7 +309,7 @@ TVM_REGISTER_GLOBAL("_HybridOp")
.set_body_typed(HybridOpNode::make);

TVM_REGISTER_GLOBAL("_OpGetOutput")
.set_body_typed<Tensor(Operation, int64_t)>([](Operation op, int64_t output) {
.set_body_typed([](Operation op, int64_t output) {
return op.output(static_cast<size_t>(output));
});

Expand All @@ -322,9 +320,7 @@ TVM_REGISTER_GLOBAL("_OpInputTensors")
.set_body_method<Operation>(&OperationNode::InputTensors);

TVM_REGISTER_GLOBAL("_IterVar")
.set_body_typed<IterVar(Range, Var, int, std::string)>([](
Range dom, Var var, int iter_type, std::string thread_tag
) {
.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
return IterVarNode::make(
dom, var,
static_cast<IterVarType>(iter_type),
Expand All @@ -341,25 +337,21 @@ TVM_REGISTER_GLOBAL("_StageBind")
.set_body_method(&Stage::bind);

TVM_REGISTER_GLOBAL("_StageSplitByFactor")
.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
Stage stage, IterVar parent, Expr factor
) {
.set_body_typed([](Stage stage, IterVar parent, Expr factor) {
IterVar outer, inner;
stage.split(parent, factor, &outer, &inner);
return Array<IterVar>({outer, inner});
});

TVM_REGISTER_GLOBAL("_StageSplitByNParts")
.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
Stage stage, IterVar parent, Expr nparts
) {
.set_body_typed([](Stage stage, IterVar parent, Expr nparts) {
IterVar outer, inner;
stage.split_by_nparts(parent, nparts, &outer, &inner);
return Array<IterVar>({outer, inner});
});

TVM_REGISTER_GLOBAL("_StageFuse")
.set_body_typed<IterVar(Stage, Array<IterVar>)>([](Stage stage, Array<IterVar> axes) {
.set_body_typed([](Stage stage, Array<IterVar> axes) {
IterVar fused;
stage.fuse(axes, &fused);
return fused;
Expand All @@ -378,7 +370,7 @@ TVM_REGISTER_GLOBAL("_StageReorder")
.set_body_method(&Stage::reorder);

TVM_REGISTER_GLOBAL("_StageTile")
.set_body_typed<Array<IterVar>(Stage, IterVar, IterVar, Expr, Expr)>([](
.set_body_typed([](
Stage stage,
IterVar x_parent, IterVar y_parent,
Expr x_factor, Expr y_factor
Expand Down
Loading