From 89daa47550f7e74abf7b2028896074a838bea7e1 Mon Sep 17 00:00:00 2001 From: samskalicky Date: Tue, 11 Feb 2020 09:37:07 +0000 Subject: [PATCH] initial commit --- include/mxnet/lib_api.h | 16 ++++++--- src/c_api/c_api_symbolic.cc | 6 +++- .../partitioner/custom_subgraph_property.h | 36 +++++++++++++++++-- 3 files changed, 50 insertions(+), 8 deletions(-) diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index aeb5f79e2f70..29640c672923 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -917,10 +917,15 @@ typedef int (*partCallSupportedOps_t)(supportedOps_t supportedOps, const char *j const char* const* opt_vals, int num_opts); #define MXLIB_PARTCALLACCEPTSUBGRAPH_STR "_partCallAcceptSubgraph" -typedef int (*partCallAcceptSubgraph_t)(acceptSubgraph_t acceptSubgraph, const char *json, - int subgraph_id, int *accept, const char* const* opt_keys, +typedef int (*partCallAcceptSubgraph_t)(acceptSubgraph_t acceptSubgraph, + const char *json, int subgraph_id, + int *accept, const char* const* opt_keys, const char* const* opt_vals, int num_opts, - char*** attr_keys, char*** attr_vals, int *num_attrs); + char*** attr_keys, char*** attr_vals, + int *num_attrs, const char* const* in_args_chars, + void* const* in_args_data, + const int64_t* const *in_args_shapes, + const int* in_args_dims, const int* in_args_types); #define MXLIB_INITIALIZE_STR "initialize" typedef int (*initialize_t)(int version); @@ -1283,7 +1288,10 @@ extern "C" { _partCallAcceptSubgraph(acceptSubgraph_t acceptSubgraph, const char *json, int subgraph_id, int *accept, const char* const* opt_keys, const char* const* opt_vals, int num_opts, - char*** attr_keys, char*** attr_vals, int *num_attrs) { + char*** attr_keys, char*** attr_vals, int *num_attrs, + const char* const* in_args_chars, void* const* in_args_data, + const int64_t* const* in_args_shapes, int* in_args_dims, + int* in_args_types) { std::string subgraph_json(json); bool accept_bool = false; // create map of attributes from list diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 0776bc701dd7..4957dac94ed2 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -1351,8 +1351,9 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, nnvm::Symbol *sym = static_cast(sym_handle); *s = sym->Copy(); nnvm::Graph g = Symbol2Graph(*s); + NDArray **in_args_ptr = nullptr; if (len) { - NDArray **in_args_ptr = reinterpret_cast(in_args_handle); + in_args_ptr = reinterpret_cast(in_args_handle); Context default_ctx = Context::Create(static_cast(dev_type), 0); mxnet::ShapeVector arg_shapes(len); nnvm::DTypeVector arg_dtypes(len); @@ -1382,6 +1383,9 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, g.GetAttr("storage_type")); } } + g.attrs["args"] = std::make_shared(in_args_ptr); + std::vector names = sym->ListInputNames(nnvm::Symbol::ListInputOption(1)); + g.attrs["arg_names"] = std::make_shared(names); std::vector> options_map; for (mx_uint i = 0; i < num_options; ++i) { options_map.emplace_back(keys[i], vals[i]); diff --git a/src/operator/subgraph/partitioner/custom_subgraph_property.h b/src/operator/subgraph/partitioner/custom_subgraph_property.h index 5d0629c25190..351138b4c73b 100644 --- a/src/operator/subgraph/partitioner/custom_subgraph_property.h +++ b/src/operator/subgraph/partitioner/custom_subgraph_property.h @@ -99,7 +99,27 @@ class CustomSubgraphProperty: public SubgraphProperty { const std::vector>& options_map) { // clear supported_nodes to remove state from previous calls supported_nodes.clear(); + + in_args_ptr = g.GetAttr("args"); + in_args_names = g.GetAttr>("arg_names"); + for(std::string s : in_args_names) { + in_args_chars.push_back(s.c_str()); + } + + in_args_data.clear(); + in_args_shapes.clear(); + in_args_dims.clear(); + in_args_types.clear(); + + // convert NDarrays to constituent parts + for (size_t i = 0; i < in_args_names.size(); i++) { + in_args_data.push_back(in_args_ptr[i]->data().dptr_); + in_args_shapes.push_back(in_args_ptr[i]->shape().data()); + in_args_dims.push_back(in_args_ptr[i]->shape().ndim()); + in_args_types.push_back(in_args_ptr[i]->dtype()); + } + // remove all graph attrs, some cannot be saved to json nnvm::Graph graph = std::move(g); graph.attrs.clear(); @@ -189,9 +209,12 @@ class CustomSubgraphProperty: public SubgraphProperty { std::string subgraph_json = nnvm::pass::SaveJSON(g); CHECK(call_accept_subgraph_(accept_subgraph_, subgraph_json.c_str(), - subgraph_id, &accept, opt_keys_.data(), - opt_vals_.data(), opt_keys_.size(), - &attr_keys, &attr_vals, &num_attr)) + subgraph_id, &accept, opt_keys_.data(), + opt_vals_.data(), opt_keys_.size(), + &attr_keys, &attr_vals, &num_attr, + in_args_chars.data(), in_args_data.data(), + in_args_shapes.data(), in_args_dims.data(), + in_args_types.data())) << "Error calling accept_subgraph for '" << subgraph_prop << "'"; } if (accept) { @@ -228,6 +251,13 @@ class CustomSubgraphProperty: public SubgraphProperty { std::string subgraph_op_name; std::vector> options_map_; std::vector opt_keys_, opt_vals_; + NDArray **in_args_ptr; + std::vector in_args_names; + std::vector in_args_chars; + std::vector in_args_data; + std::vector in_args_shapes; + std::vector in_args_dims; + std::vector in_args_types; }; } // namespace op } // namespace mxnet