From 327ad3d9255194b4dbcd964524d9c74bebe899e4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 11 Sep 2020 07:12:36 +0000 Subject: [PATCH 1/2] initial commit --- example/extensions/lib_custom_op/gemm_lib.cc | 3 ++ example/extensions/lib_custom_op/relu_lib.cu | 6 ++++ .../lib_custom_op/transposecsr_lib.cc | 3 ++ .../lib_custom_op/transposerowsp_lib.cc | 3 ++ .../extensions/lib_subgraph/subgraph_lib.cc | 3 ++ include/mxnet/lib_api.h | 13 +++++--- src/c_api/c_api.cc | 30 +++++++++++++++++-- src/lib_api.cc | 23 ++++++++++++-- 8 files changed, 75 insertions(+), 9 deletions(-) diff --git a/example/extensions/lib_custom_op/gemm_lib.cc b/example/extensions/lib_custom_op/gemm_lib.cc index f8d1d326a008..7ebd781d5a0e 100644 --- a/example/extensions/lib_custom_op/gemm_lib.cc +++ b/example/extensions/lib_custom_op/gemm_lib.cc @@ -203,6 +203,9 @@ class MyStatefulGemm : public CustomStatefulOp { }; MXReturnValue createOpState(const std::unordered_map& attrs, + const MXContext& ctx, + const std::vector >& in_shapes, + const std::vector in_types, CustomStatefulOp** op_inst) { // testing passing of keyword arguments int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0; diff --git a/example/extensions/lib_custom_op/relu_lib.cu b/example/extensions/lib_custom_op/relu_lib.cu index e4aa8a3decc3..8a0a92e158a7 100644 --- a/example/extensions/lib_custom_op/relu_lib.cu +++ b/example/extensions/lib_custom_op/relu_lib.cu @@ -168,12 +168,18 @@ class MyStatefulReluGPU : public CustomStatefulOp { }; MXReturnValue createOpStateCPU(const std::unordered_map& attrs, + const MXContext& ctx, + const std::vector >& in_shapes, + const std::vector in_types, CustomStatefulOp** op_inst) { *op_inst = new MyStatefulReluCPU(attrs); return MX_SUCCESS; } MXReturnValue createOpStateGPU(const std::unordered_map& attrs, + const MXContext& ctx, + const std::vector >& in_shapes, + const std::vector in_types, CustomStatefulOp** op_inst) { *op_inst = new MyStatefulReluGPU(attrs); return MX_SUCCESS; diff --git a/example/extensions/lib_custom_op/transposecsr_lib.cc b/example/extensions/lib_custom_op/transposecsr_lib.cc index 80053ec55cd5..5ea3fbcd7908 100644 --- a/example/extensions/lib_custom_op/transposecsr_lib.cc +++ b/example/extensions/lib_custom_op/transposecsr_lib.cc @@ -175,6 +175,9 @@ class MyStatefulTransposeCSR : public CustomStatefulOp { }; MXReturnValue createOpState(const std::unordered_map& attrs, + const MXContext& ctx, + const std::vector >& in_shapes, + const std::vector in_types, CustomStatefulOp** op_inst) { // testing passing of keyword arguments int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0; diff --git a/example/extensions/lib_custom_op/transposerowsp_lib.cc b/example/extensions/lib_custom_op/transposerowsp_lib.cc index d6addb39c4d3..07ab653f7ad7 100644 --- a/example/extensions/lib_custom_op/transposerowsp_lib.cc +++ b/example/extensions/lib_custom_op/transposerowsp_lib.cc @@ -177,6 +177,9 @@ class MyStatefulTransposeRowSP : public CustomStatefulOp { }; MXReturnValue createOpState(const std::unordered_map& attrs, + const MXContext& ctx, + const std::vector >& in_shapes, + const std::vector in_types, CustomStatefulOp** op_inst) { // testing passing of keyword arguments int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0; diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc index 51ff77456be1..f47109343007 100644 --- a/example/extensions/lib_subgraph/subgraph_lib.cc +++ b/example/extensions/lib_subgraph/subgraph_lib.cc @@ -155,6 +155,9 @@ class MyStatefulOp : public CustomStatefulOp { }; MXReturnValue createOpState(const std::unordered_map& attrs, + const MXContext& ctx, + const std::vector >& in_shapes, + const std::vector in_types, CustomStatefulOp** op_inst) { std::string serialized_subgraph = "[empty]"; // MXNet subgraph is stored as Symbol in operator node attrs subgraphs field diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index 1ac45ba81b49..ad610ce0bed7 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -732,6 +732,9 @@ typedef MXReturnValue (*mutateInputs_t)(const std::unordered_map* input_indices); typedef MXReturnValue (*createOpState_t)(const std::unordered_map& attributes, + const MXContext& ctx, + const std::vector >& in_shapes, + const std::vector in_types, CustomStatefulOp**); /*! @@ -1000,8 +1003,9 @@ typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* ke #define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState" typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* const* keys, - const char* const* vals, int num, - void** state_op); + const char* const* vals, int num, const char* dev_type, + int dev_id, unsigned int** inshapes, int* indims, + int num_in, const int* intypes, void** state_op); #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute" typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op, @@ -1190,8 +1194,9 @@ extern "C" { /*! \brief returns status of calling createStatefulOp function for operator from library */ MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys, - const char* const* vals, int num, - void** state_op); + const char* const* vals, int num, const char* dev_type, + int dev_id, unsigned int** inshapes, int* indims, + int num_in, const int* intypes, void** state_op); /*! \brief returns status of calling Stateful Forward/Backward for operator from library */ MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes, diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 5854a686f583..5e91bccde99e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1092,6 +1092,28 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, attr_vals.push_back(kv.second.c_str()); } + // string repr of supported context for custom library, currently only "cpu" and "gpu" + const char* ctx_str = ctx.dev_mask() == Context::kCPU ? "cpu" : "gpu"; + + std::vector inshapes(in_shapes.size()); + std::vector indims(in_shapes.size()); + + // determine amount of memory needed to store all the input shapes + size_t buff_size = 0; + for (size_t i = 0; i < in_shapes.size(); ++i) + buff_size += in_shapes[i].ndim(); + + // copy input shapes to raw memory layout + std::vector inbuff(buff_size); + uint32_t *ptr = inbuff.data(); + for (size_t i = 0; i < in_shapes.size(); ++i) { + inshapes[i] = ptr; + indims[i] = in_shapes[i].ndim(); + for (int j = 0; j < in_shapes[i].ndim(); ++j, ++ptr) { + *ptr = static_cast(in_shapes[i][j]); + } + } + // convert subgraph symbol from node attributes to char* std::string subgraph_json; if (!attrs.subgraphs.empty()) { @@ -1110,7 +1132,9 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, CHECK(createop_map.count("cpu") > 0) << "CPU CreateOpState not implemented for '" << name_str << "'"; int retval = callCreateOpState(createop_map.at("cpu"), attr_keys.data(), attr_vals.data(), - attr_keys.size(), &state_op_inst); + attr_keys.size(), ctx_str, ctx.real_dev_id(), + inshapes.data(), indims.data(), + in_shapes.size(), in_types.data(), &state_op_inst); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling CreateOpState CPU for custom operator '" << name_str << "'" << msgs; @@ -1118,7 +1142,9 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, CHECK(createop_map.count("gpu") > 0) << "GPU CreateOpState not implemented for '" << name_str << "'"; int retval = callCreateOpState(createop_map.at("gpu"), attr_keys.data(), attr_vals.data(), - attr_keys.size(), &state_op_inst); + attr_keys.size(), ctx_str, ctx.real_dev_id(), + inshapes.data(), indims.data(), + in_shapes.size(), in_types.data(), &state_op_inst); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling CreateOpState GPU for custom operator '" << name_str << "'" << msgs; diff --git a/src/lib_api.cc b/src/lib_api.cc index 8255095255f5..20ae280acf6c 100644 --- a/src/lib_api.cc +++ b/src/lib_api.cc @@ -1210,19 +1210,36 @@ MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* co /*! \brief returns status of calling createStatefulOp function for operator from library */ MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys, - const char* const* vals, int num, - void** state_op) { + const char* const* vals, int num, const char* dev_type, + int dev_id, unsigned int** inshapes, int* indims, + int num_in, const int* intypes, void** state_op) { // create map of attributes from list std::unordered_map attrs; for (int i = 0; i < num; i++) { attrs[std::string(keys[i])] = std::string(vals[i]); } + mxnet::ext::MXContext ctx(dev_type, dev_id); + + // create a vector of shapes for inputs + std::vector > in_shapes(num_in); + for (int i = 0; i < num_in; i++) { + for (int j = 0; j < indims[i]; j++) { + in_shapes[i].push_back(inshapes[i][j]); + } + } + + // create a vector of types for inputs + std::vector in_types(num_in); + for (int i = 0; i < num_in; i++) { + in_types[i] = intypes[i]; + } + // void pointer to hold custom state op instance created in custom library // eventually state_op pointer is populated by instance from custom library mxnet::ext::CustomStatefulOp** op_ptr = reinterpret_cast(state_op); - return create_op(attrs, op_ptr); + return create_op(attrs, ctx, in_shapes, in_types, op_ptr); } /*! \brief returns status of calling Stateful Forward/Backward for operator from library */ From 35feed663427ceafd3369a55c47d7140476851ea Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 13 Sep 2020 03:48:19 +0000 Subject: [PATCH 2/2] incremented version number --- include/mxnet/lib_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index ad610ce0bed7..0213557fdc92 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -53,7 +53,7 @@ #endif /* Make sure to update the version number everytime you make changes */ -#define MX_LIBRARY_VERSION 9 +#define MX_LIBRARY_VERSION 10 /*! * \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple