diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index b1ae3e70bb70..c3fb00e562d1 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1067,21 +1067,6 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, -/*! - * \brief infer storage type of unknown input types given the known one. - */ -MXNET_DLL int MXSymbolInferStorageType(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const int *arg_storage_type_data, - mx_uint *in_storage_type_size, - const int **in_storage_type_data, - mx_uint *out_storage_type_size, - const int **out_storage_type_data, - mx_uint *aux_storage_type_size, - const int **aux_storage_type_data, - int *complete); - //-------------------------------------------- // Part 4: Executor interface //-------------------------------------------- diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index cffca441e4b0..6de6e6bf479c 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -71,6 +71,12 @@ using FComputeEx = std::function& inputs, const std::vector& req, const std::vector& outputs)>; + +using FInferStorageType = std::function* in_attrs, + std::vector* out_attrs)>; + } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/nnvm b/nnvm index 2e3561500de9..d02104dca1ee 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit 2e3561500de99a0c173f3bc7b1a6c2b31435d6d9 +Subproject commit d02104dca1eeb174a063aa06b54b774875a9106f diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index e752eb541648..796ca77eaa13 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -19,7 +19,7 @@ from .context import Context, cpu from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP from .name import NameManager # pylint: disable=unused-import -from .ndarray import _STORAGE_TYPE_ID_TO_STR, _STORAGE_TYPE_STR_TO_ID +from .ndarray import _STORAGE_TYPE_STR_TO_ID from .sparse_ndarray import _ndarray_cls from .executor import Executor from . import _symbol_internal as _internal @@ -723,89 +723,6 @@ def list_auxiliary_states(self): self.handle, ctypes.byref(size), ctypes.byref(sarr))) return [py_str(sarr[i]) for i in range(size.value)] - def infer_storage_type(self, *args, **kwargs): - """Infer the storage type of outputs and arguments of given known types of arguments. - - User can either pass in the known types in positional way or keyword argument way. - Tuple of Nones is returned if there is not enough information passed in. - An error will be raised if there is inconsistency found in the known types passed in. - - Parameters - ---------- - *args : - Provide type of arguments in a positional way. - Unknown type can be marked as None - - **kwargs : - Provide keyword arguments of known types. - - Returns - ------- - arg_storage_types : list of numpy.dtype or None - List of types of arguments. - The order is in the same order as list_arguments() - out_storage_types : list of numpy.dtype or None - List of types of outputs. - The order is in the same order as list_outputs() - aux_storage_types : list of numpy.dtype or None - List of types of outputs. - The order is in the same order as list_auxiliary_states() - """ - # pylint: disable=too-many-locals - if len(args) != 0 and len(kwargs) != 0: - raise ValueError('Can only specify known argument \ - types either by positional or kwargs way.') - sdata = [] - if len(args) != 0: - keys = None - for s in args: - if s is not None: - if s not in _STORAGE_TYPE_STR_TO_ID or not isinstance(s, basestring): - raise TypeError('Argument need to be one of '+str(_STORAGE_TYPE_STR_TO_ID)) - sdata.append(_STORAGE_TYPE_STR_TO_ID[s]) - else: - sdata.append(_STORAGE_TYPE_STR_TO_ID['undefined']) - else: - keys = [] - for k, v in kwargs.items(): - if v in _STORAGE_TYPE_STR_TO_ID: - keys.append(c_str(k)) - sdata.append(_STORAGE_TYPE_STR_TO_ID[v]) - arg_storage_type_size = mx_uint() - arg_storage_type_data = ctypes.POINTER(ctypes.c_int)() - out_storage_type_size = mx_uint() - out_storage_type_data = ctypes.POINTER(ctypes.c_int)() - aux_storage_type_size = mx_uint() - aux_storage_type_data = ctypes.POINTER(ctypes.c_int)() - complete = ctypes.c_int() - check_call(_LIB.MXSymbolInferStorageType( - self.handle, - mx_uint(len(sdata)), - c_array(ctypes.c_char_p, keys), - c_array(ctypes.c_int, sdata), - ctypes.byref(arg_storage_type_size), - ctypes.byref(arg_storage_type_data), - ctypes.byref(out_storage_type_size), - ctypes.byref(out_storage_type_data), - ctypes.byref(aux_storage_type_size), - ctypes.byref(aux_storage_type_data), - ctypes.byref(complete))) - if complete.value != 0: - arg_storage_types = [ - _STORAGE_TYPE_ID_TO_STR[arg_storage_type_data[i]] \ - for i in range(arg_storage_type_size.value)] - out_storage_types = [ - _STORAGE_TYPE_ID_TO_STR[out_storage_type_data[i]] \ - for i in range(out_storage_type_size.value)] - aux_storage_types = [ - _STORAGE_TYPE_ID_TO_STR[aux_storage_type_data[i]] \ - for i in range(aux_storage_type_size.value)] - return (arg_storage_types, out_storage_types, aux_storage_types) - else: - return (None, None, None) - # pylint: enable=too-many-locals - - def infer_type(self, *args, **kwargs): """Infers the type of all arguments and all outputs, given the known types for some arguments. diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 8d190597ab0b..3dd491ea2c30 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -135,7 +135,7 @@ void SetShapeType(const nnvm::Op* op, std::vector& ndoutputs = *p_ndoutputs; static auto& infershape = nnvm::Op::GetAttr("FInferShape"); static auto& infertype = nnvm::Op::GetAttr("FInferType"); - static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); + static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); // infer shape std::vector& in_shapes = ret->arg_shapes; @@ -184,7 +184,7 @@ void SetShapeType(const nnvm::Op* op, out_storage_types.push_back(i.storage_type()); } if (inferstorage.count(op)) { - CHECK(inferstorage[op](attrs, &in_storage_types, &out_storage_types)); + CHECK(inferstorage[op](attrs, ctx, &in_storage_types, &out_storage_types)); CHECK_EQ(out_storage_types.size(), static_cast(infered_num_outputs)); } else { #if IMPERATIVE_EXEC_DEBUG diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index f4737fa8b3e2..f3aab484ac48 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -11,6 +11,7 @@ #include #include "./c_api_common.h" #include "../operator/operator_common.h" +#include "../executor/exec_pass.h" namespace mxnet { namespace op { @@ -457,7 +458,7 @@ int MXSymbolInferShape(SymbolHandle sym, } try { - g = nnvm::pass::InferShape(std::move(g), arg_shapes, "__shape__"); + g = mxnet::exec::InferShape(std::move(g), arg_shapes, "__shape__"); } catch (const mxnet::op::InferShapeError &err) { throw dmlc::Error(err.msg); } @@ -512,58 +513,6 @@ int MXSymbolInferShapePartial(SymbolHandle sym, &succ); } -// TODO(haibin) refactor with infer_type -int MXSymbolInferStorageType(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const int *arg_storage_type_data, - mx_uint *in_storage_type_size, - const int **in_storage_type_data, - mx_uint *out_storage_type_size, - const int **out_storage_type_data, - mx_uint *aux_storage_type_size, - const int **aux_storage_type_data, - int *complete) { - nnvm::Symbol *s = static_cast(sym); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); - API_BEGIN(); - nnvm::Graph g = Symbol2Graph(*s); - nnvm::StorageTypeVector arg_storage_types(g.indexed_graph().input_nodes().size(), - kUndefinedStorage); - if (keys == nullptr && num_args != 0) { - std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); - CHECK_LE(num_args, read_only_args.size()); - for (mx_uint i = 0; i < num_args; ++i) { - arg_storage_types[read_only_args[i]] = arg_storage_type_data[i]; - } - } else { - std::unordered_map kwargs; - for (mx_uint i = 0; i < num_args; ++i) { - kwargs[keys[i]] = arg_storage_type_data[i]; - } - mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_storage_types, "InferStorageType"); - } - - g = nnvm::pass::InferStorageType(std::move(g), arg_storage_types, "__storage_type__"); - // copy back - CopyAttr(g.indexed_graph(), g.GetAttr("storage_type"), - &(ret->arg_storage_types), &(ret->out_storage_types), &(ret->aux_storage_types)); - - *in_storage_type_size = static_cast(ret->arg_storage_types.size()); - *in_storage_type_data = dmlc::BeginPtr(ret->arg_storage_types); - *out_storage_type_size = static_cast(ret->out_storage_types.size()); - *out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types); - *in_storage_type_size = static_cast(ret->arg_storage_types.size()); - *in_storage_type_data = dmlc::BeginPtr(ret->arg_storage_types); - *out_storage_type_size = static_cast(ret->out_storage_types.size()); - *out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types); - *aux_storage_type_size = static_cast(ret->aux_storage_types.size()); - *aux_storage_type_data = dmlc::BeginPtr(ret->aux_storage_types); - *complete = (g.GetAttr("storage_type_num_unknown_nodes") == 0); - API_END(); -} - - int MXSymbolInferType(SymbolHandle sym, mx_uint num_args, const char** keys, @@ -594,7 +543,7 @@ int MXSymbolInferType(SymbolHandle sym, mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType"); } - g = nnvm::pass::InferType(std::move(g), arg_types, "__dtype__"); + g = mxnet::exec::InferType(std::move(g), arg_types, "__dtype__"); // copy back CopyAttr(g.indexed_graph(), g.GetAttr("dtype"), &(ret->arg_types), &(ret->out_types), &(ret->aux_types)); diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc index 1dd784ba2249..0bee6cf9f838 100644 --- a/src/c_api/c_predict_api.cc +++ b/src/c_api/c_predict_api.cc @@ -14,6 +14,7 @@ #include #include "./c_api_common.h" #include "../operator/operator_common.h" +#include "../executor/exec_pass.h" using namespace mxnet; @@ -176,7 +177,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str, } } nnvm::Graph g; g.outputs = sym.outputs; - g = nnvm::pass::InferShape(std::move(g), in_shapes, "__shape__"); + g = mxnet::exec::InferShape(std::move(g), in_shapes, "__shape__"); bool infer_complete = (g.GetAttr("shape_num_unknown_nodes") == 0); CHECK(infer_complete) << "The shape information of is not enough to get the shapes"; diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 20535be320d9..9be2d6c2f672 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -10,8 +10,10 @@ #include #include #include +#include #include #include +#include namespace mxnet { namespace exec { @@ -107,6 +109,45 @@ Graph AttachOpResources(Graph g); */ Graph DetectInplaceAddTo(Graph g); +/*! + * \brief Infer shapes in the graph given the information. + * \param graph The input graph. + * \param shape_inputs The shapes of input symbols to the graph. + * \param shape_attr_key The key to the node attribute that can indicate shape. This is + * the place where manual hint for shapes could be injected. + * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. + * The index of ShapeVector is given by graph.indexed_graph().entry_id. + */ +Graph InferShape(Graph graph, + nnvm::ShapeVector shape_inputs, + const std::string& shape_attr_key = ""); + +/*! + * \brief Infer types in the graph given the information. + * \param graph The input graph. + * \param dtype_inputs The types of input symbols to the graph. + * \param dtype_attr_key The key to the node attribute that can indicate types. This is + * the place where manual hint for types could be injected. + * \return A graph with new attribute "dtype" containing inferred type of each NodeEntry. + * The index of ShapeVector is given by graph.indexed_graph().entry_id. + */ +Graph InferType(Graph graph, + nnvm::DTypeVector dtype_inputs, + const std::string& dtype_attr_key = ""); + +/*! + * \brief Infer storage types in the graph given the information. + * \param graph The input graph. + * \param storage_type_inputs The storage types of input symbols to the graph. + * \param storage_type_attr_key The key to the node attribute that can indicate storage types. + This is the place where manual hint for types could be injected. + * \return A graph with new attribute "storage_type" containing inferred type of each NodeEntry. + * The index of StorageTypeVector is given by graph.indexed_graph().entry_id. + */ +Graph InferStorageType(Graph graph, + nnvm::StorageTypeVector storage_type_inputs, + const std::string& storage_type_attr_key = ""); + } // namespace exec } // namespace mxnet diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index de8411a7be95..cf428ff5701d 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -430,6 +430,29 @@ void HandleInferTypeError(const size_t num_forward_inputs, << oss.str(); } +void HandleInferStorageTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::StorageTypeVector& inferred_stypes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const int inferred_stype = inferred_stypes[eid]; + if (inferred_stype == -1) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_stype << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferStoragetType pass cannot decide storage type for the following arguments " + "(-1 means unknown stype). Please consider providing them as inputs:\n" + << oss.str(); +} + /*! * \brief GraphExecutor initializer for regular bind flow in which * input arguments and gradients are provided by users. This initializer @@ -501,20 +524,24 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // expand arg_shapes and arg_dtypes to contain backward inputs arg_shapes.resize(idx.input_nodes().size(), TShape()); - g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + g = InferShape(std::move(g), arg_shapes, "__shape__"); if (g.GetAttr("shape_num_unknown_nodes") != 0U) { HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("shape")); } arg_dtypes.resize(idx.input_nodes().size(), -1); - g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + g = InferType(std::move(g), arg_dtypes, "__dtype__"); if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("dtype")); } - // TODO(haibin) better error message for infer_storage - g = nnvm::pass::InferStorageType(g, arg_stypes, "__storage_type__"); + + g = InferStorageType(std::move(g), arg_stypes, "__storage_type__"); + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } // Initialize the rest attributes of the graph. // This function can be called by regular bind @@ -877,7 +904,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const nnvm::IndexedGraph& idx = g.indexed_graph(); nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); - nnvm::DTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); + nnvm::StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& name = idx[nid].source->attrs.name; @@ -894,19 +921,23 @@ void GraphExecutor::Init(nnvm::Symbol symbol, arg_stypes[i] = it3->second; } } - g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + g = InferShape(std::move(g), arg_shapes, "__shape__"); if (g.GetAttr("shape_num_unknown_nodes") != 0U) { HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("shape")); } - g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + g = InferType(std::move(g), arg_dtypes, "__dtype__"); if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("dtype")); } - // TODO(jun/haibin) check if InferShape is successful, and give warnings instead of segfault later - g = nnvm::pass::InferStorageType(g, arg_stypes, "__storage_type__"); + + g = InferStorageType(std::move(g), arg_stypes, "__storage_type__"); + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes. diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc new file mode 100644 index 000000000000..3789c313bf18 --- /dev/null +++ b/src/executor/infer_graph_attr_pass.cc @@ -0,0 +1,337 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file infer_graph_attr_pass.cc + * \brief infer graph shape, dtype, and storage type + */ + +#include +#include "./exec_pass.h" + +namespace mxnet { +namespace exec { + +template +bool ApplyOpInferAttr(const nnvm::Graph& g, + const FInfer& finfer, + const NodeAttrs& attrs, + const uint32_t nid, + std::vector* in_attrs, + std::vector* out_attrs) { + return finfer(attrs, in_attrs, out_attrs); +} + +template<> +bool ApplyOpInferAttr(const nnvm::Graph& g, + const FInferStorageType& finfer, + const NodeAttrs& attrs, + const uint32_t nid, + std::vector* in_attrs, + std::vector* out_attrs) { + const ContextVector& ctxes = g.GetAttr("context"); + return finfer(attrs, ctxes[nid], in_attrs, out_attrs); +} + +/*!\brief + * This is a duplicate of the InferAttr function in nnvm with minor modification + * to support inferring storage type whose function signature is different from + * shape/type inference functions'. The nnvm InferAttr will be deprecated + * in the future. Please use interfaces InferShape, InferType, and InferStorageType + * to call this function. + */ +template +nnvm::Graph InferAttr(nnvm::Graph &&ret, + const AttrType empty_val, + const char* infer_name, + const char* input_name, + const char* attr_key_name, + const char* attr_name, + const char* unknown_name, + IsNone fis_none, + FDefault fdefault, + bool backward_identity_assign) { + using nnvm::IndexedGraph; + using nnvm::Op; + using AttrVector = std::vector; + using dmlc::any; + + const IndexedGraph& idx = ret.indexed_graph(); + static auto& finfer_shape = + Op::GetAttr(infer_name); + static auto& is_backward = + Op::GetAttr("TIsBackward"); + // gradient function, used to get node correspondence. + static auto& fgrad = + Op::GetAttr("FGradient"); + // reshape shape vector + AttrVector rshape; + if (ret.attrs.count(attr_name) != 0) { + rshape = ret.MoveCopyAttr(attr_name); + } else { + rshape.resize(idx.num_node_entries(), empty_val); + } + + if (ret.attrs.count(input_name) != 0) { + const AttrVector& shape_args = ret.GetAttr(input_name); + CHECK_LE(shape_args.size(), idx.input_nodes().size()) + << "More provided " << attr_name << "s than number of arguments."; + for (size_t i = 0; i < shape_args.size(); ++i) { + rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i]; + } + // erase the provided arguments + ret.attrs.erase(input_name); + } + + // get the shape hints + std::string shape_hints_key = std::string(attr_name) + "_hints"; + if (ret.attrs.count(shape_hints_key)) { + nnvm::NodeEntryMap shape_hints = + ret.GetAttr>(shape_hints_key); + for (const auto& kv : shape_hints) { + nnvm::NodeEntry e = kv.first; + if (idx.exist(e.node.get())) { + rshape[idx.entry_id(kv.first)] = kv.second; + } + } + } + + std::string shape_attr_key; + if (ret.attrs.count(attr_key_name) != 0) { + shape_attr_key = ret.GetAttr(attr_key_name); + // erase the provided arguments + ret.attrs.erase(attr_key_name); + } + // Temp space for shape inference. + std::vector ishape, oshape; + + // inference step function for nid + auto infer_step = [&](uint32_t nid, bool last_iter) { + const auto& inode = idx[nid]; + const uint32_t num_inputs = inode.inputs.size(); + const uint32_t num_outputs = inode.source->num_outputs(); + if (inode.source->is_variable()) { + // Variable node. No operator. Only one output entry. + CHECK(inode.source->op() == nullptr); + CHECK_EQ(num_outputs, 1U); + const uint32_t out_ent_id = idx.entry_id(nid, 0); + if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) { + auto it = inode.source->attrs.dict.find(shape_attr_key); + if (it != inode.source->attrs.dict.end()) { + std::istringstream is(it->second); + CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; + } + } + } else if (is_backward.get(inode.source->op(), false) && + inode.control_deps.size() && backward_identity_assign) { + CHECK_GE(inode.control_deps.size(), 1U) + << "BackwardOp need to have control_deps to its forward op"; + const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; + nnvm::NodePtr fwd_ptr = inode.source->control_deps[0]; + CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; + // use gradient function to find out the correspondence. + std::vector ograd(fwd_ptr->num_outputs()); + for (size_t i = 0; i < ograd.size(); ++i) { + ograd[i].index = static_cast(i); + } + // input gradient list + auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); + const nnvm::Node* igrad_node = nullptr; + // Input gradient assignement + for (size_t i = 0; i < igrad.size(); ++i) { + if (igrad[i].node->op() == inode.source->op()) { + uint32_t eid = idx.entry_id(nid, igrad[i].index); + if (fis_none(rshape[eid])) { + rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])]; + } else { + CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])]) + << "Backward shape inconsistent with the forward shape"; + } + if (igrad_node == nullptr) { + igrad_node = igrad[i].node.get(); + } else { + CHECK(igrad_node == igrad[i].node.get()); + } + } + } + // out grad entries + CHECK(igrad_node != nullptr) + << "Cannot find matching backward op for " << inode.source->attrs.name; + for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { + const nnvm::NodeEntry& e = igrad_node->inputs[i]; + if (e.node == nullptr) { + uint32_t eid = idx.entry_id(inode.inputs[i]); + if (fis_none(rshape[eid])) { + rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)]; + } + } + } + } else { + bool forward_known = true; + // Forward operator inference. + ishape.resize(num_inputs, empty_val); + for (uint32_t i = 0; i < ishape.size(); ++i) { + ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; + if (fis_none(ishape[i])) forward_known = false; + } + oshape.resize(num_outputs, empty_val); + for (uint32_t i = 0; i < oshape.size(); ++i) { + oshape[i] = rshape[idx.entry_id(nid, i)]; + if (fis_none(oshape[i])) forward_known = false; + } + auto finfer = finfer_shape.get(inode.source->op(), fdefault); + if (!forward_known) { + if (finfer != nullptr) { + // Call inference function of the operator. + try { + forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs, + nid, &ishape, &oshape); + } catch (const std::exception& e) { + throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what()); + } + } else { + CHECK(!last_iter) + << "Attribute " << infer_name + << " is not registed by op " << inode.source->op()->name + << " we are not able to complete the inference because of this"; + } + } + // Save to the result map. + for (uint32_t i = 0; i < num_inputs; ++i) { + rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; + } + for (uint32_t i = 0; i < num_outputs; ++i) { + rshape[idx.entry_id(nid, i)] = oshape[i]; + } + } + }; + + size_t last_num_unknown; + size_t num_unknown = rshape.size(); + int i = 0; + do { + if (i % 2 == 0) { + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + infer_step(nid, false); + } + } else { + // backward inference + for (uint32_t i = idx.num_nodes(); i != 0; --i) { + infer_step(i - 1, false); + } + } + last_num_unknown = num_unknown; + num_unknown = 0; + for (size_t j = 0; j < idx.num_node_entries(); ++j) { + if (fis_none(rshape[j])) { + ++num_unknown; + } + } + ++i; + } while (num_unknown > 0 && last_num_unknown > num_unknown); + // set the shapes + ret.attrs[attr_name] = std::make_shared(std::move(rshape)); + // number of nodes who knows the shape. + ret.attrs[unknown_name] = std::make_shared(num_unknown); + return ret; +} + +// inference fucntion for same type +inline bool SameType(const nnvm::NodeAttrs& attrs, + std::vector *iattr, + std::vector *oattr) { + int def_v = -1; + for (int v : *oattr) { + if (v != -1) { + def_v = v; break; + } + } + if (def_v == -1) { + for (int v : *iattr) { + if (v != -1) { + def_v = v; break; + } + } + } + if (def_v == -1) return false; + for (int& v : *oattr) { + v = def_v; + } + for (int& v : *iattr) { + v = def_v; + } + return true; +} + +// assigning default type N to both input and output attrs with value -1 +template +inline bool DefaultType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *iattr, + std::vector *oattr) { + // TODO(junwu): check whether need to use ctx + for (int& v : *oattr) { + if (v == none) v = default_val; + } + for (int& v : *iattr) { + if (v == none) v = default_val; + } + return true; +} + +nnvm::Graph InferShape(nnvm::Graph graph, + nnvm::ShapeVector shape_inputs, + const std::string& shape_attr_key) { + using dmlc::any; + if (shape_inputs.size() != 0) { + graph.attrs["shape_inputs"] = std::make_shared(std::move(shape_inputs)); + } + if (shape_attr_key.length() != 0) { + graph.attrs["shape_attr_key"] = std::make_shared(std::move(shape_attr_key)); + } + return InferAttr( + std::move(graph), nnvm::TShape(), + "FInferShape", "shape_inputs", "shape_attr_key", + "shape", "shape_num_unknown_nodes", + [](const nnvm::TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, + nullptr, true); +} + +nnvm::Graph InferType(nnvm::Graph graph, + nnvm::DTypeVector dtype_inputs, + const std::string& dtype_attr_key) { + using dmlc::any; + if (dtype_inputs.size() != 0) { + graph.attrs["dtype_inputs"] = std::make_shared(std::move(dtype_inputs)); + } + if (dtype_attr_key.length() != 0) { + graph.attrs["dtype_attr_key"] = std::make_shared(std::move(dtype_attr_key)); + } + return InferAttr( + std::move(graph), -1, + "FInferType", "dtype_inputs", "dtype_attr_key", + "dtype", "dtype_num_unknown_nodes", + [](const int t) { return t == -1; }, + SameType, true); +} + +nnvm::Graph InferStorageType(nnvm::Graph graph, + nnvm::StorageTypeVector storage_type_inputs, + const std::string& storage_type_attr_key) { + using dmlc::any; + if (storage_type_inputs.size() != 0) { + graph.attrs["storage_type_inputs"] = std::make_shared(std::move(storage_type_inputs)); + } + if (storage_type_attr_key.length() != 0) { + graph.attrs["storage_type_attr_key"] = std::make_shared(std::move(storage_type_attr_key)); + } + // for storage type, the backward attr is not necessarily the same as it's correspondence + const int kDefaultStorage = 0; + return InferAttr( + std::move(graph), -1, + "FInferStorageType", "storage_type_inputs", "storage_type_attr_key", + "storage_type", "storage_type_num_unknown_nodes", + [](const int t) { return t == -1; }, + DefaultType, false); +} + +} // namespace exec +} // namespace mxnet diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index 3f2000f6ee99..441add472339 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -111,8 +111,10 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs, template inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + // TODO(junwu): add ctx info into storage inference logic CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; return ElemwiseStorageAttr( @@ -120,8 +122,10 @@ inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, } inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { + // TODO(junwu): add ctx info into storage inference logic CHECK_EQ(in_attrs->size(), static_cast(2)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(1)) << " in operator " << attrs.name; auto &in = *in_attrs; diff --git a/src/operator/nn/cast_storage-inl.h b/src/operator/nn/cast_storage-inl.h index 1fb32045b9a0..003161f8797a 100644 --- a/src/operator/nn/cast_storage-inl.h +++ b/src/operator/nn/cast_storage-inl.h @@ -302,6 +302,7 @@ struct CastStorageParam : public dmlc::Parameter { }; inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); diff --git a/src/operator/nn/cast_storage.cc b/src/operator/nn/cast_storage.cc index 21c13e8fa564..c435146a730b 100644 --- a/src/operator/nn/cast_storage.cc +++ b/src/operator/nn/cast_storage.cc @@ -21,7 +21,7 @@ NNVM_REGISTER_OP(cast_storage) .set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FInferStorageType", CastStorageInferStorageType) +.set_attr("FInferStorageType", CastStorageInferStorageType) .set_attr("FCompute", IdentityCompute) .set_attr("FComputeEx", CastStorageComputeEx) .add_argument("data", "NDArray-or-Symbol", "The input.") diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index c9e5b21470d9..37e073172d28 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -14,7 +14,7 @@ MXNET_OPERATOR_REGISTER_BINARY(elemwise_add) .set_attr("FCompute", BinaryCompute) .set_attr("FComputeEx", BinaryComputeEx) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_add"}) -.set_attr("FInferStorageType", ElemwiseStorageType<2, 1>); +.set_attr("FInferStorageType", ElemwiseStorageType<2, 1>); // specialized gradient add function to do add to optimization // this must differ from elemwise_add to prevent add to optimization in forward pass. @@ -33,7 +33,7 @@ NNVM_REGISTER_OP(_backward_add) mshadow_op::identity>) .set_attr("FComputeEx", BinaryBackwardUseNoneEx) -.set_attr("FInferStorageType", ElemwiseStorageType<1, 2>); +.set_attr("FInferStorageType", ElemwiseStorageType<1, 2>); MXNET_OPERATOR_REGISTER_BINARY(_sub) .add_alias("_minus").add_alias("_Minus") diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc index 372e94509a68..078d62b5f96e 100644 --- a/src/operator/tensor/elemwise_unary_op.cc +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -126,7 +126,7 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs) .set_attr("FCompute", IdentityCompute) .set_attr("FComputeEx", IdentityLikeRhsComputeEx) .set_attr("FInferShape", ElemwiseShape<2, 1>) -.set_attr("FInferStorageType", IdentityAttrLikeRhsStorageType) +.set_attr("FInferStorageType", IdentityAttrLikeRhsStorageType) .set_attr( "FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index dfe53cf4614e..f55f7d8cf563 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -107,7 +107,7 @@ The gradient of an embedding matrix has the form of gradient vectors that are on }) .set_attr("FInferShape", SparseEmbeddingShape) .set_attr("FInferType", EmbeddingOpType) -.set_attr("FInferStorageType", SparseEmbeddingForwardStorageType) +.set_attr("FInferStorageType", SparseEmbeddingForwardStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; @@ -296,7 +296,7 @@ Example:: }) .set_attr("FInferShape", SparseRetainOpShape) .set_attr("FInferType", SparseRetainOpType) -.set_attr("FInferStorageType", SparseRetainForwardInferStorageType) +.set_attr("FInferStorageType", SparseRetainForwardInferStorageType) .set_attr("FComputeEx", SparseRetainOpForwardEx) .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { @@ -310,7 +310,7 @@ NNVM_REGISTER_OP(_backward_sparse_retain) .set_num_inputs(2) .set_num_outputs(2) .set_attr("TIsBackward", true) -.set_attr("FInferStorageType", SparseRetainBackwardInferStorageType) +.set_attr("FInferStorageType", SparseRetainBackwardInferStorageType) .set_attr("FComputeEx", SparseRetainOpBackwardEx); } // namespace op diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index b2a67f73af78..213c2d2c3313 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -247,6 +247,7 @@ void SparseEmbeddingForwardEx(const nnvm::NodeAttrs& attrs, } inline bool SparseEmbeddingForwardStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); @@ -800,6 +801,7 @@ inline bool SparseRetainOpType(const nnvm::NodeAttrs& attrs, } inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); @@ -811,6 +813,7 @@ inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs, } inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 8ba10bfd0e27..0aad920d641c 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -478,6 +478,7 @@ void DotBackward_(const nnvm::NodeAttrs& attrs, } inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); @@ -487,6 +488,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, } inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 3U); diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 9ac998f02378..72d8aadbe90a 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -248,7 +248,7 @@ Example:: .set_attr_parser(ParamParser) .set_attr("FInferShape", SliceShape) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FInferStorageType", ElemwiseStorageType<1, 1>) +.set_attr("FInferStorageType", ElemwiseStorageType<1, 1>) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_slice"}) .set_attr("FCompute", Slice) .set_attr("FComputeEx", SliceEx) @@ -375,7 +375,7 @@ NNVM_REGISTER_OP(dot) }) .set_attr("FInferShape", DotShape) .set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FInferStorageType", DotForwardInferStorageType) +.set_attr("FInferStorageType", DotForwardInferStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; @@ -392,7 +392,7 @@ NNVM_REGISTER_OP(_backward_dot) .set_num_outputs(2) .set_attr_parser(ParamParser) .set_attr("TIsBackward", true) -.set_attr("FInferStorageType", DotBackwardInferStorageType) +.set_attr("FInferStorageType", DotBackwardInferStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py index 9188dd9d933f..ceb965b43a72 100644 --- a/tests/python/unittest/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -129,19 +129,6 @@ def test_fc_infer_type(): for k, v in true_types.items(): assert arg_type_dict[k] == v -def check_infer_storage(v1, v2, v1_storage, v2_storage, out_chunk): - out = mx.symbol.elemwise_add(v1, v2) - arg_storage_types, out_storage_types, aux_storage_types = out.infer_storage_type(v1=v1_storage, v2=v2_storage) - assert len(out_storage_types) == 1 - assert out_storage_types[0] == out_chunk - -def test_elemwise_add_infer_storage_type(): - v1 = mx.symbol.Variable('v1') - v2 = mx.symbol.Variable('v2') - check_infer_storage(v1, v2, 'default', 'default', 'default') - check_infer_storage(v1, v2, 'default', 'row_sparse', 'default') - check_infer_storage(v1, v2, 'row_sparse', 'default', 'default') - check_infer_storage(v1, v2, 'row_sparse', 'row_sparse', 'row_sparse') if __name__ == "__main__": test_mlp2_infer_shape() @@ -152,4 +139,3 @@ def test_elemwise_add_infer_storage_type(): test_incomplete_infer_slicechannel() test_incomplete_infer_convolution() test_incomplete_infer_concat() - test_elemwise_add_infer_storage_type()