From 30d543c301463a04d8e563686b6db385845f3d54 Mon Sep 17 00:00:00 2001 From: reminisce Date: Wed, 21 Jun 2017 15:59:31 -0700 Subject: [PATCH 1/2] Move InferAttr to mxnet from nnvm Replace nnvm infer attr functions in c_api Initial checkin Clean up Remove nnvm namespace for FInferShape, FInferType, and FInferStorageType Add new interface for InferStorageType Revert "Remove nnvm namespace for FInferShape, FInferType, and FInferStorageType" This reverts commit 8aedf054bfe29b076c6fcb6f54d996fd2752e4de. Fix and clean up Fix lint Add nnvm changes Change infer function interface to accept only rvalue reference of graph Clean up Flush commits to show up in PR Add error handling for storage type inference failure Update nnvm --- include/mxnet/c_api.h | 15 - include/mxnet/op_attr_types.h | 6 + nnvm | 2 +- python/mxnet/symbol.py | 83 ----- src/c_api/c_api_ndarray.cc | 4 +- src/c_api/c_api_symbolic.cc | 57 +-- src/c_api/c_predict_api.cc | 3 +- src/executor/exec_pass.h | 41 +++ src/executor/graph_executor.cc | 49 ++- src/executor/infer_graph_attr_pass.cc | 337 ++++++++++++++++++ src/operator/elemwise_op_common.h | 8 +- src/operator/nn/cast_storage-inl.h | 1 + src/operator/nn/cast_storage.cc | 2 +- .../tensor/elemwise_binary_op_basic.cc | 4 +- src/operator/tensor/elemwise_unary_op.cc | 2 +- src/operator/tensor/indexing_op.cc | 6 +- src/operator/tensor/indexing_op.h | 3 + src/operator/tensor/matrix_op-inl.h | 2 + src/operator/tensor/matrix_op.cc | 6 +- tests/python/unittest/test_infer_shape.py | 14 - 20 files changed, 454 insertions(+), 191 deletions(-) create mode 100644 src/executor/infer_graph_attr_pass.cc diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index b1ae3e70bb70..c3fb00e562d1 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1067,21 +1067,6 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, -/*! - * \brief infer storage type of unknown input types given the known one. - */ -MXNET_DLL int MXSymbolInferStorageType(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const int *arg_storage_type_data, - mx_uint *in_storage_type_size, - const int **in_storage_type_data, - mx_uint *out_storage_type_size, - const int **out_storage_type_data, - mx_uint *aux_storage_type_size, - const int **aux_storage_type_data, - int *complete); - //-------------------------------------------- // Part 4: Executor interface //-------------------------------------------- diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index cffca441e4b0..6de6e6bf479c 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -71,6 +71,12 @@ using FComputeEx = std::function& inputs, const std::vector& req, const std::vector& outputs)>; + +using FInferStorageType = std::function* in_attrs, + std::vector* out_attrs)>; + } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/nnvm b/nnvm index 2e3561500de9..d02104dca1ee 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit 2e3561500de99a0c173f3bc7b1a6c2b31435d6d9 +Subproject commit d02104dca1eeb174a063aa06b54b774875a9106f diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index e752eb541648..4b2421b5eb52 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -723,89 +723,6 @@ def list_auxiliary_states(self): self.handle, ctypes.byref(size), ctypes.byref(sarr))) return [py_str(sarr[i]) for i in range(size.value)] - def infer_storage_type(self, *args, **kwargs): - """Infer the storage type of outputs and arguments of given known types of arguments. 
- - User can either pass in the known types in positional way or keyword argument way. - Tuple of Nones is returned if there is not enough information passed in. - An error will be raised if there is inconsistency found in the known types passed in. - - Parameters - ---------- - *args : - Provide type of arguments in a positional way. - Unknown type can be marked as None - - **kwargs : - Provide keyword arguments of known types. - - Returns - ------- - arg_storage_types : list of numpy.dtype or None - List of types of arguments. - The order is in the same order as list_arguments() - out_storage_types : list of numpy.dtype or None - List of types of outputs. - The order is in the same order as list_outputs() - aux_storage_types : list of numpy.dtype or None - List of types of outputs. - The order is in the same order as list_auxiliary_states() - """ - # pylint: disable=too-many-locals - if len(args) != 0 and len(kwargs) != 0: - raise ValueError('Can only specify known argument \ - types either by positional or kwargs way.') - sdata = [] - if len(args) != 0: - keys = None - for s in args: - if s is not None: - if s not in _STORAGE_TYPE_STR_TO_ID or not isinstance(s, basestring): - raise TypeError('Argument need to be one of '+str(_STORAGE_TYPE_STR_TO_ID)) - sdata.append(_STORAGE_TYPE_STR_TO_ID[s]) - else: - sdata.append(_STORAGE_TYPE_STR_TO_ID['undefined']) - else: - keys = [] - for k, v in kwargs.items(): - if v in _STORAGE_TYPE_STR_TO_ID: - keys.append(c_str(k)) - sdata.append(_STORAGE_TYPE_STR_TO_ID[v]) - arg_storage_type_size = mx_uint() - arg_storage_type_data = ctypes.POINTER(ctypes.c_int)() - out_storage_type_size = mx_uint() - out_storage_type_data = ctypes.POINTER(ctypes.c_int)() - aux_storage_type_size = mx_uint() - aux_storage_type_data = ctypes.POINTER(ctypes.c_int)() - complete = ctypes.c_int() - check_call(_LIB.MXSymbolInferStorageType( - self.handle, - mx_uint(len(sdata)), - c_array(ctypes.c_char_p, keys), - c_array(ctypes.c_int, sdata), - ctypes.byref(arg_storage_type_size), - ctypes.byref(arg_storage_type_data), - ctypes.byref(out_storage_type_size), - ctypes.byref(out_storage_type_data), - ctypes.byref(aux_storage_type_size), - ctypes.byref(aux_storage_type_data), - ctypes.byref(complete))) - if complete.value != 0: - arg_storage_types = [ - _STORAGE_TYPE_ID_TO_STR[arg_storage_type_data[i]] \ - for i in range(arg_storage_type_size.value)] - out_storage_types = [ - _STORAGE_TYPE_ID_TO_STR[out_storage_type_data[i]] \ - for i in range(out_storage_type_size.value)] - aux_storage_types = [ - _STORAGE_TYPE_ID_TO_STR[aux_storage_type_data[i]] \ - for i in range(aux_storage_type_size.value)] - return (arg_storage_types, out_storage_types, aux_storage_types) - else: - return (None, None, None) - # pylint: enable=too-many-locals - - def infer_type(self, *args, **kwargs): """Infers the type of all arguments and all outputs, given the known types for some arguments. 
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 8d190597ab0b..3dd491ea2c30 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -135,7 +135,7 @@ void SetShapeType(const nnvm::Op* op, std::vector& ndoutputs = *p_ndoutputs; static auto& infershape = nnvm::Op::GetAttr("FInferShape"); static auto& infertype = nnvm::Op::GetAttr("FInferType"); - static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); + static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); // infer shape std::vector& in_shapes = ret->arg_shapes; @@ -184,7 +184,7 @@ void SetShapeType(const nnvm::Op* op, out_storage_types.push_back(i.storage_type()); } if (inferstorage.count(op)) { - CHECK(inferstorage[op](attrs, &in_storage_types, &out_storage_types)); + CHECK(inferstorage[op](attrs, ctx, &in_storage_types, &out_storage_types)); CHECK_EQ(out_storage_types.size(), static_cast(infered_num_outputs)); } else { #if IMPERATIVE_EXEC_DEBUG diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index f4737fa8b3e2..f3aab484ac48 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -11,6 +11,7 @@ #include #include "./c_api_common.h" #include "../operator/operator_common.h" +#include "../executor/exec_pass.h" namespace mxnet { namespace op { @@ -457,7 +458,7 @@ int MXSymbolInferShape(SymbolHandle sym, } try { - g = nnvm::pass::InferShape(std::move(g), arg_shapes, "__shape__"); + g = mxnet::exec::InferShape(std::move(g), arg_shapes, "__shape__"); } catch (const mxnet::op::InferShapeError &err) { throw dmlc::Error(err.msg); } @@ -512,58 +513,6 @@ int MXSymbolInferShapePartial(SymbolHandle sym, &succ); } -// TODO(haibin) refactor with infer_type -int MXSymbolInferStorageType(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const int *arg_storage_type_data, - mx_uint *in_storage_type_size, - const int **in_storage_type_data, - mx_uint *out_storage_type_size, - const int **out_storage_type_data, - mx_uint *aux_storage_type_size, - const int **aux_storage_type_data, - int *complete) { - nnvm::Symbol *s = static_cast(sym); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); - API_BEGIN(); - nnvm::Graph g = Symbol2Graph(*s); - nnvm::StorageTypeVector arg_storage_types(g.indexed_graph().input_nodes().size(), - kUndefinedStorage); - if (keys == nullptr && num_args != 0) { - std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); - CHECK_LE(num_args, read_only_args.size()); - for (mx_uint i = 0; i < num_args; ++i) { - arg_storage_types[read_only_args[i]] = arg_storage_type_data[i]; - } - } else { - std::unordered_map kwargs; - for (mx_uint i = 0; i < num_args; ++i) { - kwargs[keys[i]] = arg_storage_type_data[i]; - } - mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_storage_types, "InferStorageType"); - } - - g = nnvm::pass::InferStorageType(std::move(g), arg_storage_types, "__storage_type__"); - // copy back - CopyAttr(g.indexed_graph(), g.GetAttr("storage_type"), - &(ret->arg_storage_types), &(ret->out_storage_types), &(ret->aux_storage_types)); - - *in_storage_type_size = static_cast(ret->arg_storage_types.size()); - *in_storage_type_data = dmlc::BeginPtr(ret->arg_storage_types); - *out_storage_type_size = static_cast(ret->out_storage_types.size()); - *out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types); - *in_storage_type_size = static_cast(ret->arg_storage_types.size()); - *in_storage_type_data = 
dmlc::BeginPtr(ret->arg_storage_types); - *out_storage_type_size = static_cast(ret->out_storage_types.size()); - *out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types); - *aux_storage_type_size = static_cast(ret->aux_storage_types.size()); - *aux_storage_type_data = dmlc::BeginPtr(ret->aux_storage_types); - *complete = (g.GetAttr("storage_type_num_unknown_nodes") == 0); - API_END(); -} - - int MXSymbolInferType(SymbolHandle sym, mx_uint num_args, const char** keys, @@ -594,7 +543,7 @@ int MXSymbolInferType(SymbolHandle sym, mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType"); } - g = nnvm::pass::InferType(std::move(g), arg_types, "__dtype__"); + g = mxnet::exec::InferType(std::move(g), arg_types, "__dtype__"); // copy back CopyAttr(g.indexed_graph(), g.GetAttr("dtype"), &(ret->arg_types), &(ret->out_types), &(ret->aux_types)); diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc index 1dd784ba2249..0bee6cf9f838 100644 --- a/src/c_api/c_predict_api.cc +++ b/src/c_api/c_predict_api.cc @@ -14,6 +14,7 @@ #include #include "./c_api_common.h" #include "../operator/operator_common.h" +#include "../executor/exec_pass.h" using namespace mxnet; @@ -176,7 +177,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str, } } nnvm::Graph g; g.outputs = sym.outputs; - g = nnvm::pass::InferShape(std::move(g), in_shapes, "__shape__"); + g = mxnet::exec::InferShape(std::move(g), in_shapes, "__shape__"); bool infer_complete = (g.GetAttr("shape_num_unknown_nodes") == 0); CHECK(infer_complete) << "The shape information of is not enough to get the shapes"; diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 20535be320d9..9be2d6c2f672 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -10,8 +10,10 @@ #include #include #include +#include #include #include +#include namespace mxnet { namespace exec { @@ -107,6 +109,45 @@ Graph AttachOpResources(Graph g); */ Graph DetectInplaceAddTo(Graph g); +/*! + * \brief Infer shapes in the graph given the information. + * \param graph The input graph. + * \param shape_inputs The shapes of input symbols to the graph. + * \param shape_attr_key The key to the node attribute that can indicate shape. This is + * the place where manual hint for shapes could be injected. + * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. + * The index of ShapeVector is given by graph.indexed_graph().entry_id. + */ +Graph InferShape(Graph graph, + nnvm::ShapeVector shape_inputs, + const std::string& shape_attr_key = ""); + +/*! + * \brief Infer types in the graph given the information. + * \param graph The input graph. + * \param dtype_inputs The types of input symbols to the graph. + * \param dtype_attr_key The key to the node attribute that can indicate types. This is + * the place where manual hint for types could be injected. + * \return A graph with new attribute "dtype" containing inferred type of each NodeEntry. + * The index of ShapeVector is given by graph.indexed_graph().entry_id. + */ +Graph InferType(Graph graph, + nnvm::DTypeVector dtype_inputs, + const std::string& dtype_attr_key = ""); + +/*! + * \brief Infer storage types in the graph given the information. + * \param graph The input graph. + * \param storage_type_inputs The storage types of input symbols to the graph. + * \param storage_type_attr_key The key to the node attribute that can indicate storage types. + This is the place where manual hint for types could be injected. 
+ * \return A graph with new attribute "storage_type" containing inferred type of each NodeEntry. + * The index of StorageTypeVector is given by graph.indexed_graph().entry_id. + */ +Graph InferStorageType(Graph graph, + nnvm::StorageTypeVector storage_type_inputs, + const std::string& storage_type_attr_key = ""); + } // namespace exec } // namespace mxnet diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index de8411a7be95..cf428ff5701d 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -430,6 +430,29 @@ void HandleInferTypeError(const size_t num_forward_inputs, << oss.str(); } +void HandleInferStorageTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::StorageTypeVector& inferred_stypes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const int inferred_stype = inferred_stypes[eid]; + if (inferred_stype == -1) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_stype << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferStoragetType pass cannot decide storage type for the following arguments " + "(-1 means unknown stype). Please consider providing them as inputs:\n" + << oss.str(); +} + /*! * \brief GraphExecutor initializer for regular bind flow in which * input arguments and gradients are provided by users. This initializer @@ -501,20 +524,24 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // expand arg_shapes and arg_dtypes to contain backward inputs arg_shapes.resize(idx.input_nodes().size(), TShape()); - g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + g = InferShape(std::move(g), arg_shapes, "__shape__"); if (g.GetAttr("shape_num_unknown_nodes") != 0U) { HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("shape")); } arg_dtypes.resize(idx.input_nodes().size(), -1); - g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + g = InferType(std::move(g), arg_dtypes, "__dtype__"); if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("dtype")); } - // TODO(haibin) better error message for infer_storage - g = nnvm::pass::InferStorageType(g, arg_stypes, "__storage_type__"); + + g = InferStorageType(std::move(g), arg_stypes, "__storage_type__"); + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } // Initialize the rest attributes of the graph. 
// This function can be called by regular bind @@ -877,7 +904,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const nnvm::IndexedGraph& idx = g.indexed_graph(); nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); - nnvm::DTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); + nnvm::StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& name = idx[nid].source->attrs.name; @@ -894,19 +921,23 @@ void GraphExecutor::Init(nnvm::Symbol symbol, arg_stypes[i] = it3->second; } } - g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + g = InferShape(std::move(g), arg_shapes, "__shape__"); if (g.GetAttr("shape_num_unknown_nodes") != 0U) { HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("shape")); } - g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + g = InferType(std::move(g), arg_dtypes, "__dtype__"); if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("dtype")); } - // TODO(jun/haibin) check if InferShape is successful, and give warnings instead of segfault later - g = nnvm::pass::InferStorageType(g, arg_stypes, "__storage_type__"); + + g = InferStorageType(std::move(g), arg_stypes, "__storage_type__"); + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes. diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc new file mode 100644 index 000000000000..3789c313bf18 --- /dev/null +++ b/src/executor/infer_graph_attr_pass.cc @@ -0,0 +1,337 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file infer_graph_attr_pass.cc + * \brief infer graph shape, dtype, and storage type + */ + +#include +#include "./exec_pass.h" + +namespace mxnet { +namespace exec { + +template +bool ApplyOpInferAttr(const nnvm::Graph& g, + const FInfer& finfer, + const NodeAttrs& attrs, + const uint32_t nid, + std::vector* in_attrs, + std::vector* out_attrs) { + return finfer(attrs, in_attrs, out_attrs); +} + +template<> +bool ApplyOpInferAttr(const nnvm::Graph& g, + const FInferStorageType& finfer, + const NodeAttrs& attrs, + const uint32_t nid, + std::vector* in_attrs, + std::vector* out_attrs) { + const ContextVector& ctxes = g.GetAttr("context"); + return finfer(attrs, ctxes[nid], in_attrs, out_attrs); +} + +/*!\brief + * This is a duplicate of the InferAttr function in nnvm with minor modification + * to support inferring storage type whose function signature is different from + * shape/type inference functions'. The nnvm InferAttr will be deprecated + * in the future. Please use interfaces InferShape, InferType, and InferStorageType + * to call this function. 
+ */ +template +nnvm::Graph InferAttr(nnvm::Graph &&ret, + const AttrType empty_val, + const char* infer_name, + const char* input_name, + const char* attr_key_name, + const char* attr_name, + const char* unknown_name, + IsNone fis_none, + FDefault fdefault, + bool backward_identity_assign) { + using nnvm::IndexedGraph; + using nnvm::Op; + using AttrVector = std::vector; + using dmlc::any; + + const IndexedGraph& idx = ret.indexed_graph(); + static auto& finfer_shape = + Op::GetAttr(infer_name); + static auto& is_backward = + Op::GetAttr("TIsBackward"); + // gradient function, used to get node correspondence. + static auto& fgrad = + Op::GetAttr("FGradient"); + // reshape shape vector + AttrVector rshape; + if (ret.attrs.count(attr_name) != 0) { + rshape = ret.MoveCopyAttr(attr_name); + } else { + rshape.resize(idx.num_node_entries(), empty_val); + } + + if (ret.attrs.count(input_name) != 0) { + const AttrVector& shape_args = ret.GetAttr(input_name); + CHECK_LE(shape_args.size(), idx.input_nodes().size()) + << "More provided " << attr_name << "s than number of arguments."; + for (size_t i = 0; i < shape_args.size(); ++i) { + rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i]; + } + // erase the provided arguments + ret.attrs.erase(input_name); + } + + // get the shape hints + std::string shape_hints_key = std::string(attr_name) + "_hints"; + if (ret.attrs.count(shape_hints_key)) { + nnvm::NodeEntryMap shape_hints = + ret.GetAttr>(shape_hints_key); + for (const auto& kv : shape_hints) { + nnvm::NodeEntry e = kv.first; + if (idx.exist(e.node.get())) { + rshape[idx.entry_id(kv.first)] = kv.second; + } + } + } + + std::string shape_attr_key; + if (ret.attrs.count(attr_key_name) != 0) { + shape_attr_key = ret.GetAttr(attr_key_name); + // erase the provided arguments + ret.attrs.erase(attr_key_name); + } + // Temp space for shape inference. + std::vector ishape, oshape; + + // inference step function for nid + auto infer_step = [&](uint32_t nid, bool last_iter) { + const auto& inode = idx[nid]; + const uint32_t num_inputs = inode.inputs.size(); + const uint32_t num_outputs = inode.source->num_outputs(); + if (inode.source->is_variable()) { + // Variable node. No operator. Only one output entry. + CHECK(inode.source->op() == nullptr); + CHECK_EQ(num_outputs, 1U); + const uint32_t out_ent_id = idx.entry_id(nid, 0); + if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) { + auto it = inode.source->attrs.dict.find(shape_attr_key); + if (it != inode.source->attrs.dict.end()) { + std::istringstream is(it->second); + CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; + } + } + } else if (is_backward.get(inode.source->op(), false) && + inode.control_deps.size() && backward_identity_assign) { + CHECK_GE(inode.control_deps.size(), 1U) + << "BackwardOp need to have control_deps to its forward op"; + const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; + nnvm::NodePtr fwd_ptr = inode.source->control_deps[0]; + CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; + // use gradient function to find out the correspondence. 
+ std::vector ograd(fwd_ptr->num_outputs()); + for (size_t i = 0; i < ograd.size(); ++i) { + ograd[i].index = static_cast(i); + } + // input gradient list + auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); + const nnvm::Node* igrad_node = nullptr; + // Input gradient assignement + for (size_t i = 0; i < igrad.size(); ++i) { + if (igrad[i].node->op() == inode.source->op()) { + uint32_t eid = idx.entry_id(nid, igrad[i].index); + if (fis_none(rshape[eid])) { + rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])]; + } else { + CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])]) + << "Backward shape inconsistent with the forward shape"; + } + if (igrad_node == nullptr) { + igrad_node = igrad[i].node.get(); + } else { + CHECK(igrad_node == igrad[i].node.get()); + } + } + } + // out grad entries + CHECK(igrad_node != nullptr) + << "Cannot find matching backward op for " << inode.source->attrs.name; + for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { + const nnvm::NodeEntry& e = igrad_node->inputs[i]; + if (e.node == nullptr) { + uint32_t eid = idx.entry_id(inode.inputs[i]); + if (fis_none(rshape[eid])) { + rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)]; + } + } + } + } else { + bool forward_known = true; + // Forward operator inference. + ishape.resize(num_inputs, empty_val); + for (uint32_t i = 0; i < ishape.size(); ++i) { + ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; + if (fis_none(ishape[i])) forward_known = false; + } + oshape.resize(num_outputs, empty_val); + for (uint32_t i = 0; i < oshape.size(); ++i) { + oshape[i] = rshape[idx.entry_id(nid, i)]; + if (fis_none(oshape[i])) forward_known = false; + } + auto finfer = finfer_shape.get(inode.source->op(), fdefault); + if (!forward_known) { + if (finfer != nullptr) { + // Call inference function of the operator. + try { + forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs, + nid, &ishape, &oshape); + } catch (const std::exception& e) { + throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what()); + } + } else { + CHECK(!last_iter) + << "Attribute " << infer_name + << " is not registed by op " << inode.source->op()->name + << " we are not able to complete the inference because of this"; + } + } + // Save to the result map. + for (uint32_t i = 0; i < num_inputs; ++i) { + rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; + } + for (uint32_t i = 0; i < num_outputs; ++i) { + rshape[idx.entry_id(nid, i)] = oshape[i]; + } + } + }; + + size_t last_num_unknown; + size_t num_unknown = rshape.size(); + int i = 0; + do { + if (i % 2 == 0) { + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + infer_step(nid, false); + } + } else { + // backward inference + for (uint32_t i = idx.num_nodes(); i != 0; --i) { + infer_step(i - 1, false); + } + } + last_num_unknown = num_unknown; + num_unknown = 0; + for (size_t j = 0; j < idx.num_node_entries(); ++j) { + if (fis_none(rshape[j])) { + ++num_unknown; + } + } + ++i; + } while (num_unknown > 0 && last_num_unknown > num_unknown); + // set the shapes + ret.attrs[attr_name] = std::make_shared(std::move(rshape)); + // number of nodes who knows the shape. 
+ ret.attrs[unknown_name] = std::make_shared(num_unknown); + return ret; +} + +// inference fucntion for same type +inline bool SameType(const nnvm::NodeAttrs& attrs, + std::vector *iattr, + std::vector *oattr) { + int def_v = -1; + for (int v : *oattr) { + if (v != -1) { + def_v = v; break; + } + } + if (def_v == -1) { + for (int v : *iattr) { + if (v != -1) { + def_v = v; break; + } + } + } + if (def_v == -1) return false; + for (int& v : *oattr) { + v = def_v; + } + for (int& v : *iattr) { + v = def_v; + } + return true; +} + +// assigning default type N to both input and output attrs with value -1 +template +inline bool DefaultType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *iattr, + std::vector *oattr) { + // TODO(junwu): check whether need to use ctx + for (int& v : *oattr) { + if (v == none) v = default_val; + } + for (int& v : *iattr) { + if (v == none) v = default_val; + } + return true; +} + +nnvm::Graph InferShape(nnvm::Graph graph, + nnvm::ShapeVector shape_inputs, + const std::string& shape_attr_key) { + using dmlc::any; + if (shape_inputs.size() != 0) { + graph.attrs["shape_inputs"] = std::make_shared(std::move(shape_inputs)); + } + if (shape_attr_key.length() != 0) { + graph.attrs["shape_attr_key"] = std::make_shared(std::move(shape_attr_key)); + } + return InferAttr( + std::move(graph), nnvm::TShape(), + "FInferShape", "shape_inputs", "shape_attr_key", + "shape", "shape_num_unknown_nodes", + [](const nnvm::TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, + nullptr, true); +} + +nnvm::Graph InferType(nnvm::Graph graph, + nnvm::DTypeVector dtype_inputs, + const std::string& dtype_attr_key) { + using dmlc::any; + if (dtype_inputs.size() != 0) { + graph.attrs["dtype_inputs"] = std::make_shared(std::move(dtype_inputs)); + } + if (dtype_attr_key.length() != 0) { + graph.attrs["dtype_attr_key"] = std::make_shared(std::move(dtype_attr_key)); + } + return InferAttr( + std::move(graph), -1, + "FInferType", "dtype_inputs", "dtype_attr_key", + "dtype", "dtype_num_unknown_nodes", + [](const int t) { return t == -1; }, + SameType, true); +} + +nnvm::Graph InferStorageType(nnvm::Graph graph, + nnvm::StorageTypeVector storage_type_inputs, + const std::string& storage_type_attr_key) { + using dmlc::any; + if (storage_type_inputs.size() != 0) { + graph.attrs["storage_type_inputs"] = std::make_shared(std::move(storage_type_inputs)); + } + if (storage_type_attr_key.length() != 0) { + graph.attrs["storage_type_attr_key"] = std::make_shared(std::move(storage_type_attr_key)); + } + // for storage type, the backward attr is not necessarily the same as it's correspondence + const int kDefaultStorage = 0; + return InferAttr( + std::move(graph), -1, + "FInferStorageType", "storage_type_inputs", "storage_type_attr_key", + "storage_type", "storage_type_num_unknown_nodes", + [](const int t) { return t == -1; }, + DefaultType, false); +} + +} // namespace exec +} // namespace mxnet diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index 3f2000f6ee99..441add472339 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -111,8 +111,10 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs, template inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + // TODO(junwu): add ctx info into storage inference logic CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << 
attrs.name; CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; return ElemwiseStorageAttr( @@ -120,8 +122,10 @@ inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, } inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { + // TODO(junwu): add ctx info into storage inference logic CHECK_EQ(in_attrs->size(), static_cast(2)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(1)) << " in operator " << attrs.name; auto &in = *in_attrs; diff --git a/src/operator/nn/cast_storage-inl.h b/src/operator/nn/cast_storage-inl.h index 1fb32045b9a0..003161f8797a 100644 --- a/src/operator/nn/cast_storage-inl.h +++ b/src/operator/nn/cast_storage-inl.h @@ -302,6 +302,7 @@ struct CastStorageParam : public dmlc::Parameter { }; inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); diff --git a/src/operator/nn/cast_storage.cc b/src/operator/nn/cast_storage.cc index 21c13e8fa564..c435146a730b 100644 --- a/src/operator/nn/cast_storage.cc +++ b/src/operator/nn/cast_storage.cc @@ -21,7 +21,7 @@ NNVM_REGISTER_OP(cast_storage) .set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FInferStorageType", CastStorageInferStorageType) +.set_attr("FInferStorageType", CastStorageInferStorageType) .set_attr("FCompute", IdentityCompute) .set_attr("FComputeEx", CastStorageComputeEx) .add_argument("data", "NDArray-or-Symbol", "The input.") diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index c9e5b21470d9..37e073172d28 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -14,7 +14,7 @@ MXNET_OPERATOR_REGISTER_BINARY(elemwise_add) .set_attr("FCompute", BinaryCompute) .set_attr("FComputeEx", BinaryComputeEx) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_add"}) -.set_attr("FInferStorageType", ElemwiseStorageType<2, 1>); +.set_attr("FInferStorageType", ElemwiseStorageType<2, 1>); // specialized gradient add function to do add to optimization // this must differ from elemwise_add to prevent add to optimization in forward pass. 
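For context, operator-level storage-type inference functions now take the node's Context in addition to the input/output storage-type vectors, and are registered through the mxnet::FInferStorageType functor added in include/mxnet/op_attr_types.h rather than the old nnvm one. Below is a minimal sketch of such a function under the new signature; the operator and function names are illustrative only, and kUndefinedStorage/kDefaultStorage are assumed to be the storage-type enum values used elsewhere in this patch.

#include <vector>
#include <dmlc/logging.h>
#include <mxnet/op_attr_types.h>   // mxnet::FInferStorageType (added by this patch)

namespace mxnet {
namespace op {

// Hypothetical 1-input/1-output operator: propagate a known storage type,
// falling back to default (dense) storage when nothing is known yet.
inline bool ExampleInferStorageType(const nnvm::NodeAttrs& attrs,
                                    const Context& ctx,        // new parameter in this patch
                                    std::vector<int>* in_attrs,
                                    std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U) << " in operator " << attrs.name;
  CHECK_EQ(out_attrs->size(), 1U) << " in operator " << attrs.name;
  int stype = (*in_attrs)[0];
  if (stype == kUndefinedStorage) stype = kDefaultStorage;     // assumed enum values
  if ((*in_attrs)[0] == kUndefinedStorage) (*in_attrs)[0] = stype;
  if ((*out_attrs)[0] == kUndefinedStorage) (*out_attrs)[0] = stype;
  return true;  // all storage types decided
}

}  // namespace op
}  // namespace mxnet

// Registration would then use the mxnet functor type, as the registrations in
// this patch do:
//   .set_attr<FInferStorageType>("FInferStorageType", mxnet::op::ExampleInferStorageType)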
@@ -33,7 +33,7 @@ NNVM_REGISTER_OP(_backward_add) mshadow_op::identity>) .set_attr("FComputeEx", BinaryBackwardUseNoneEx) -.set_attr("FInferStorageType", ElemwiseStorageType<1, 2>); +.set_attr("FInferStorageType", ElemwiseStorageType<1, 2>); MXNET_OPERATOR_REGISTER_BINARY(_sub) .add_alias("_minus").add_alias("_Minus") diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc index 372e94509a68..078d62b5f96e 100644 --- a/src/operator/tensor/elemwise_unary_op.cc +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -126,7 +126,7 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs) .set_attr("FCompute", IdentityCompute) .set_attr("FComputeEx", IdentityLikeRhsComputeEx) .set_attr("FInferShape", ElemwiseShape<2, 1>) -.set_attr("FInferStorageType", IdentityAttrLikeRhsStorageType) +.set_attr("FInferStorageType", IdentityAttrLikeRhsStorageType) .set_attr( "FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index dfe53cf4614e..f55f7d8cf563 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -107,7 +107,7 @@ The gradient of an embedding matrix has the form of gradient vectors that are on }) .set_attr("FInferShape", SparseEmbeddingShape) .set_attr("FInferType", EmbeddingOpType) -.set_attr("FInferStorageType", SparseEmbeddingForwardStorageType) +.set_attr("FInferStorageType", SparseEmbeddingForwardStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; @@ -296,7 +296,7 @@ Example:: }) .set_attr("FInferShape", SparseRetainOpShape) .set_attr("FInferType", SparseRetainOpType) -.set_attr("FInferStorageType", SparseRetainForwardInferStorageType) +.set_attr("FInferStorageType", SparseRetainForwardInferStorageType) .set_attr("FComputeEx", SparseRetainOpForwardEx) .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { @@ -310,7 +310,7 @@ NNVM_REGISTER_OP(_backward_sparse_retain) .set_num_inputs(2) .set_num_outputs(2) .set_attr("TIsBackward", true) -.set_attr("FInferStorageType", SparseRetainBackwardInferStorageType) +.set_attr("FInferStorageType", SparseRetainBackwardInferStorageType) .set_attr("FComputeEx", SparseRetainOpBackwardEx); } // namespace op diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index b2a67f73af78..213c2d2c3313 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -247,6 +247,7 @@ void SparseEmbeddingForwardEx(const nnvm::NodeAttrs& attrs, } inline bool SparseEmbeddingForwardStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); @@ -800,6 +801,7 @@ inline bool SparseRetainOpType(const nnvm::NodeAttrs& attrs, } inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); @@ -811,6 +813,7 @@ inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs, } inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 8ba10bfd0e27..0aad920d641c 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h 
@@ -478,6 +478,7 @@ void DotBackward_(const nnvm::NodeAttrs& attrs, } inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); @@ -487,6 +488,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, } inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 3U); diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 9ac998f02378..72d8aadbe90a 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -248,7 +248,7 @@ Example:: .set_attr_parser(ParamParser) .set_attr("FInferShape", SliceShape) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FInferStorageType", ElemwiseStorageType<1, 1>) +.set_attr("FInferStorageType", ElemwiseStorageType<1, 1>) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_slice"}) .set_attr("FCompute", Slice) .set_attr("FComputeEx", SliceEx) @@ -375,7 +375,7 @@ NNVM_REGISTER_OP(dot) }) .set_attr("FInferShape", DotShape) .set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FInferStorageType", DotForwardInferStorageType) +.set_attr("FInferStorageType", DotForwardInferStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; @@ -392,7 +392,7 @@ NNVM_REGISTER_OP(_backward_dot) .set_num_outputs(2) .set_attr_parser(ParamParser) .set_attr("TIsBackward", true) -.set_attr("FInferStorageType", DotBackwardInferStorageType) +.set_attr("FInferStorageType", DotBackwardInferStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py index 9188dd9d933f..ceb965b43a72 100644 --- a/tests/python/unittest/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -129,19 +129,6 @@ def test_fc_infer_type(): for k, v in true_types.items(): assert arg_type_dict[k] == v -def check_infer_storage(v1, v2, v1_storage, v2_storage, out_chunk): - out = mx.symbol.elemwise_add(v1, v2) - arg_storage_types, out_storage_types, aux_storage_types = out.infer_storage_type(v1=v1_storage, v2=v2_storage) - assert len(out_storage_types) == 1 - assert out_storage_types[0] == out_chunk - -def test_elemwise_add_infer_storage_type(): - v1 = mx.symbol.Variable('v1') - v2 = mx.symbol.Variable('v2') - check_infer_storage(v1, v2, 'default', 'default', 'default') - check_infer_storage(v1, v2, 'default', 'row_sparse', 'default') - check_infer_storage(v1, v2, 'row_sparse', 'default', 'default') - check_infer_storage(v1, v2, 'row_sparse', 'row_sparse', 'row_sparse') if __name__ == "__main__": test_mlp2_infer_shape() @@ -152,4 +139,3 @@ def test_elemwise_add_infer_storage_type(): test_incomplete_infer_slicechannel() test_incomplete_infer_convolution() test_incomplete_infer_concat() - test_elemwise_add_infer_storage_type() From 78b21ef78c3f58816a9517071e4f2afbc6c4ec69 Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 26 Jun 2017 22:05:36 -0700 Subject: [PATCH 2/2] Fix pylint --- python/mxnet/symbol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 4b2421b5eb52..796ca77eaa13 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -19,7 +19,7 @@ from .context import Context, cpu from .ndarray import 
NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP from .name import NameManager # pylint: disable=unused-import -from .ndarray import _STORAGE_TYPE_ID_TO_STR, _STORAGE_TYPE_STR_TO_ID +from .ndarray import _STORAGE_TYPE_STR_TO_ID from .sparse_ndarray import _ndarray_cls from .executor import Executor from . import _symbol_internal as _internal
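Taken together, callers now drive shape, dtype, and storage-type inference through the mxnet::exec passes instead of nnvm::pass::InferShape/InferType/InferStorageType. A minimal sketch of that calling pattern, modeled on the GraphExecutor::Init changes in this patch, follows; the wrapper function, its name, and the include paths are illustrative and not part of the patch.

#include <utility>
#include <dmlc/logging.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include "exec_pass.h"   // mxnet::exec::InferShape / InferType / InferStorageType

nnvm::Graph InferAll(nnvm::Graph g,
                     nnvm::ShapeVector arg_shapes,
                     nnvm::DTypeVector arg_dtypes,
                     nnvm::StorageTypeVector arg_stypes) {
  using namespace mxnet::exec;
  // Shapes: "shape_num_unknown_nodes" reports how many entries remain unknown.
  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
  CHECK_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0U) << "incomplete shape inference";
  // Dtypes follow the same convention.
  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
  CHECK_EQ(g.GetAttr<size_t>("dtype_num_unknown_nodes"), 0U) << "incomplete dtype inference";
  // New in this patch: storage types are inferred by the same pass machinery and
  // report failures through "storage_type_num_unknown_nodes".
  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
  CHECK_EQ(g.GetAttr<size_t>("storage_type_num_unknown_nodes"), 0U)
      << "incomplete storage type inference";
  return g;
}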