diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 06e39bfeb38b..79c92edfa0a1 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1056,6 +1056,28 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, */ MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, const char **name); + +/*! + * \brief Get the input symbols of the graph. + * \param sym The graph. + * \param inputs The input symbols of the graph. + * \param input_size the number of input symbols returned. + */ +MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs, + int *input_size); + +/*! + * \brief Cut a subgraph whose nodes are marked with a subgraph attribute. + * The input graph will be modified. A variable node will be created for each + * edge that connects to nodes outside the subgraph. The outside nodes that + * connect to the subgraph will be returned. + * \param sym The graph. + * \param inputs The nodes that connect to the subgraph. + * \param input_size The number of such nodes. + */ +MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs, + int *input_size); + /*! * \brief Get the detailed information about atomic symbol. * \param creator the AtomicSymbolCreator. diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index e243eb71c477..897b5e882aaa 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -693,6 +693,10 @@ class NDArray { NDArray MKLDNNDataReshape(const TShape &shape) const; #endif + const nnvm::NodeEntry &entry() const { + return entry_; + } + /*! * \brief Save list of ndarray into the Stream.x * \param fo The stream of output. diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 3969d8445be1..23a318464f15 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -64,8 +64,10 @@ enum OpReqType { * \sa Resource */ struct OpContext { + /*! \brief whether there is a backward phase to compute gradients. */ + bool need_grad; /*! \brief whether it is training phase */ - int is_train; + bool is_train; /*! \brief RunContext related resources */ RunContext run_ctx; /*! \brief the callback when operation completes, used by asynchronize ops */ diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index ba402e6f3f8d..d68698f71d66 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -21,6 +21,8 @@ import math from ..context import current_context from ..random import uniform +from ..base import _as_list +from . import ndarray try: from .gen_contrib import * except ImportError: @@ -96,3 +98,96 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None): expected_count_sampled = expected_prob_sampled * num_sampled return sampled_classes, expected_count_true, expected_count_sampled # pylint: enable=line-too-long + +def foreach(body, data, init_states): + """Run a for loop with user-defined computation over NDArrays on dimension 0. + + This operator simulates a for loop and body has the computation for an iteration + of the for loop. It runs the computation in body on each slice from the input + NDArrays. + + body takes two arguments as input and outputs a tuple of two elements, + as illustrated below: + + out, states = body(data1, states) + + data1 can be either an NDArray or a list of NDArrays. If data is an NDArray, + data1 is an NDArray. Otherwise, data1 is a list of NDArrays and has the same + size as data. 
states is a list of NDArrays and has the same length as init_states.
+    Similarly, out can be either an NDArray or a list of NDArrays, which are concatenated
+    as the first output of foreach; states from the last execution of body
+    are the second output of foreach.
+
+    The computation done by this operator is equivalent to the pseudo code below
+    when the input data is an NDArray:
+
+       states = init_states
+       outs = []
+       for i in range(data.shape[0]):
+           s = data[i]
+           out, states = body(s, states)
+           outs.append(out)
+       outs = stack(*outs)
+
+    Parameters
+    ----------
+    body : a Python function.
+        Define computation in an iteration.
+    data: an NDArray or a list of NDArrays.
+        The input data.
+    init_states: an NDArray or a list of NDArrays.
+        The initial values of the loop states.
+    name: string.
+        The name of the operator.
+
+    Returns
+    -------
+    outputs: an NDArray or a list of NDArrays.
+        The output data concatenated from the output of all iterations.
+    states: a list of NDArrays.
+        The loop states in the last iteration.
+
+    Examples
+    --------
+    >>> step = lambda data, states: (data + states[0], [states[0] * 2])
+    >>> data = mx.nd.random.uniform(shape=(2, 10))
+    >>> states = [mx.nd.random.uniform(shape=(10))]
+    >>> outs, states = mx.nd.contrib.foreach(step, data, states)
+    """
+
+    def check_input(inputs, in_type, msg):
+        is_NDArray_or_list = True
+        if isinstance(inputs, list):
+            for i in inputs:
+                if not isinstance(i, in_type):
+                    is_NDArray_or_list = False
+                    break
+        else:
+            is_NDArray_or_list = isinstance(inputs, in_type)
+        assert is_NDArray_or_list, msg
+
+    check_input(data, ndarray.NDArray, "data should be an NDArray or a list of NDArrays")
+    check_input(init_states, ndarray.NDArray,
+                "init_states should be an NDArray or a list of NDArrays")
+
+    not_data_list = isinstance(data, ndarray.NDArray)
+    not_state_list = isinstance(init_states, ndarray.NDArray)
+    num_iters = data.shape[0] if not_data_list else data[0].shape[0]
+    states = init_states
+    outputs = []
+    for i in range(num_iters):
+        if not_data_list:
+            eles = data[i]
+        else:
+            eles = [d[i] for d in data]
+        outs, states = body(eles, states)
+        outs = _as_list(outs)
+        outputs.append(outs)
+    # zip() returns an iterator in Python 3, so materialize it as a list
+    # before assigning the stacked results back by index.
+    outputs = list(zip(*outputs))
+    for j, out in enumerate(outputs):
+        outputs[j] = ndarray.op.stack(*out)
+
+    if not_data_list:
+        outputs = outputs[0]
+    return (outputs, states)
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 83e90e687327..a1a1d23bbbe5 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -19,6 +19,9 @@
 # pylint: disable=wildcard-import, unused-wildcard-import
 """Contrib Symbol API of MXNet."""
 import math
+import ctypes
+import re
+
 from .random import uniform
 from .symbol import Symbol
 try:
@@ -26,6 +29,11 @@
 except ImportError:
     pass
 
+from . import symbol
+from ..base import _LIB, c_array, check_call
+from ..base import SymbolHandle, _as_list
+from ..attribute import AttrScope
+
 __all__ = ["rand_zipfian"]
 
 def rand_zipfian(true_classes, num_sampled, range_max):
@@ -91,3 +99,196 @@ def rand_zipfian(true_classes, num_sampled, range_max):
     expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
     expected_count_sampled = expected_prob_sampled * num_sampled
     return sampled_classes, expected_count_true, expected_count_sampled
+
+def _get_graph_inputs(subg):
+    num_handles = ctypes.c_int(1000)
+    handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)])
+    check_call(_LIB.MXSymbolGetInputSymbols(subg.handle, handles, ctypes.byref(num_handles)))
+
+    syms = []
+    for i in range(num_handles.value):
+        s = Symbol(handles[i])
+        syms.append(s)
+    return syms
+
+def _cut_subgraph(subg):
+    num_handles = ctypes.c_int(1000)
+    handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)])
+    check_call(_LIB.MXSymbolCutSubgraph(subg.handle, handles, ctypes.byref(num_handles)))
+
+    syms = []
+    for i in range(num_handles.value):
+        s = Symbol(handles[i])
+        syms.append(s)
+    return syms
+
+def foreach(body, data, init_states, name="foreach"):
+    """Run a for loop with user-defined computation over Symbols on dimension 0.
+
+    This operator simulates a for loop and body has the computation for an iteration
+    of the for loop. It runs the computation in body on each slice from the input
+    NDArrays.
+
+    body takes two arguments as input and outputs a tuple of two elements,
+    as illustrated below:
+
+    out, states = body(data1, states)
+
+    data1 can be either a symbol or a list of symbols. If data is a symbol,
+    data1 is a symbol. Otherwise, data1 is a list of symbols and has the same
+    size as data. states is a list of symbols and has the same length as init_states.
+    Similarly, out can be either a symbol or a list of symbols, which are concatenated
+    as the first output of foreach; states from the last execution of body
+    are the second output of foreach.
+
+    The computation done by this operator is equivalent to the pseudo code below
+    when the input data is an NDArray:
+
+       states = init_states
+       outs = []
+       for i in range(data.shape[0]):
+           s = data[i]
+           out, states = body(s, states)
+           outs.append(out)
+       outs = stack(*outs)
+
+    Parameters
+    ----------
+    body : a Python function.
+        Define computation in an iteration.
+    data: a symbol or a list of symbols.
+        The input data.
+    init_states: a symbol or a list of symbols.
+        The initial values of the loop states.
+    name: string.
+        The name of the operator.
+
+    Returns
+    -------
+    outputs: a Symbol or a list of Symbols.
+        The output data concatenated from the output of all iterations.
+    states: a list of Symbols.
+        The loop states in the last iteration.
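+
+    Note that, unlike the NDArray version, body is run only once here, on
+    symbolic placeholder variables, to record the loop computation as a
+    subgraph; the _foreach operator then executes that subgraph over the
+    slices at run time.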
+
+    Examples
+    --------
+    >>> step = lambda data, states: (data + states[0], [states[0] * 2])
+    >>> data = mx.sym.var('data')
+    >>> states = [mx.sym.var('state')]
+    >>> outs, states = mx.sym.contrib.foreach(step, data, states)
+    """
+
+    def check_data(inputs, in_type, msg):
+        is_NDArray_or_list = True
+        if isinstance(inputs, list):
+            for i in inputs:
+                if not isinstance(i, in_type):
+                    is_NDArray_or_list = False
+                    break
+        else:
+            is_NDArray_or_list = isinstance(inputs, in_type)
+        assert is_NDArray_or_list, msg
+
+    check_data(data, symbol.Symbol, "data should be a symbol or a list of symbols")
+    check_data(init_states, symbol.Symbol,
+               "init_states should be a symbol or a list of symbols")
+    not_state_list = isinstance(init_states, symbol.Symbol)
+
+    # TODO(zhengda) If the input Python function references symbols outside
+    # the function, we need to prune the computation graph constructed from
+    # the function. One way of doing it is to mark the nodes in the computation graph
+    # with AttrScope and prune the nodes without the special attribute.
+    with AttrScope(subgraph_name=name):
+        if isinstance(data, list):
+            in_eles = [symbol.var(sym.name) for sym in data]
+        else:
+            in_eles = symbol.var(data.name)
+        if isinstance(init_states, list):
+            states = [symbol.var(s.name) for s in init_states]
+        else:
+            states = symbol.var(init_states.name)
+        sym_out, sym_states = body(in_eles, states)
+
+        check_data(sym_out, symbol.Symbol,
+                   "the output should be a symbol or a list of symbols")
+        check_data(sym_states, symbol.Symbol,
+                   "the output states should be a symbol or a list of symbols")
+        if isinstance(sym_states, list):
+            assert isinstance(init_states, list) and len(sym_states) == len(init_states), \
+                "the number of output states (%d) should be the same as input states (%d)" \
+                % (len(sym_states), len(init_states))
+
+    if isinstance(sym_out, list):
+        flat_out = sym_out
+    else:
+        flat_out = [sym_out]
+    num_out_data = len(flat_out)
+    if isinstance(sym_states, list):
+        for s in sym_states:
+            # There is a problem if the outputs are the same as the inputs
+            # or the first output. By calling identity, we can make sure that
+            # all symbols will refer to different NDArrays.
+            flat_out.append(symbol.op.identity(s))
+    else:
+        flat_out.append(symbol.op.identity(sym_states))
+    g = symbol.Group(flat_out)
+
+    cut_syms = _cut_subgraph(g)
+    input_syms = _get_graph_inputs(g)
+
+    # Here we need to find out how the input symbols are ordered as well as
+    # where the loop states are located in the list of inputs.
+
+    # This dict contains the symbols of the subgraph.
+    input_syms = {sym.name:sym for sym in input_syms}
+    gin_names = input_syms.keys()
+    # This array contains the symbols for the inputs of foreach.
+    # They are ordered according to the inputs of the subgraph.
+    states_map = {sym.name:sym for sym in init_states}
+    state_names = states_map.keys()
+    data_syms = _as_list(data)
+    data_map = {sym.name:sym for sym in data_syms}
+    data_names = data_map.keys()
+
+    ordered_ins = []
+    in_state_locs = []
+    in_data_locs = []
+    for in_name in g.list_inputs():
+        assert in_name in gin_names, "The input variable %s can't be found in graph inputs: %s" \
+            % (in_name, str(gin_names))
+        if in_name in state_names:
+            ordered_ins.append(states_map[in_name])
+            in_state_locs.append(len(ordered_ins) - 1)
+        elif in_name in data_names:
+            ordered_ins.append(data_map[in_name])
+            in_data_locs.append(len(ordered_ins) - 1)
+        else:
+            # The remaining inputs are the ones cut from the original graph.
+            # The names of these variable nodes contain the index in cut_syms.
+            m = re.search(r'\d+$', in_name)
+            # The regex may not match if a variable was named unexpectedly,
+            # so guard against None before comparing with len(cut_syms).
+            idx = int(m.group()) if m else None
+            assert idx is not None and idx < len(cut_syms)
+            ordered_ins.append(cut_syms[idx])
+
+    num_outputs = len(flat_out)
+    num_states = len(state_names)
+    ret = symbol._internal._foreach(g, *ordered_ins, num_outputs=num_outputs,
+                                    num_out_data=num_out_data, in_state_locs=in_state_locs,
+                                    in_data_locs=in_data_locs)
+    if num_outputs - num_states > 1:
+        outs = []
+        for i in range(num_outputs - num_states):
+            outs.append(ret[i])
+    else:
+        outs = ret[0]
+    states = []
+    for i in range(num_states):
+        states.append(ret[num_outputs - num_states + i])
+
+    if not_state_list:
+        # If there is only one input state, there should be only one output state.
+        assert len(states) == 1
+        states = states[0]
+
+    return (outs, states)
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 4666b6adf0c3..030ab432228b 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -38,10 +38,11 @@ void RegisterLegacyOpProp();
 void RegisterLegacyNDFunc();
 }
 const std::vector<std::string> kHiddenKeys = {
-  "ctx_group", "lr_mult", "wd_mult", "force_mirroring", "mirror_stage"
+  "ctx_group", "lr_mult", "wd_mult", "force_mirroring", "mirror_stage", "subgraph_name"
 };
 const std::vector<std::string> kReplacedHiddenKeys = {
-  "__ctx_group__", "__lr_mult__", "__wd_mult__", "__force_mirroring__", "__mirror_stage__"
+  "__ctx_group__", "__lr_mult__", "__wd_mult__", "__force_mirroring__", "__mirror_stage__",
+  "__subgraph_name__"
 };
 const char *kNamespaceSeparator = "$";
@@ -344,6 +345,75 @@ int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
   API_END();
 }
 
+namespace mxnet {
+
+extern std::vector<nnvm::Symbol *> GetInputSymbols(const nnvm::Symbol &sym);
+extern bool CutGraph(const std::vector<nnvm::NodeEntry *> &input_entries,
+                     const std::string &in_name_prefix, bool skip_var,
+                     std::vector<nnvm::NodeEntry> *orig_entries,
+                     std::vector<std::string> *new_var_names);
+
+}
+
+int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **input_arr, int *input_size) {
+  API_BEGIN();
+  nnvm::Symbol *s = static_cast<nnvm::Symbol *>(sym);
+  size_t max_input_size = *input_size;
+  std::vector<nnvm::Symbol *> input_syms = mxnet::GetInputSymbols(*s);
+  CHECK(input_syms.size() <= max_input_size);
+  *input_size = input_syms.size();
+  memcpy(input_arr, input_syms.data(), sizeof(*input_arr) * input_syms.size());
+  API_END_HANDLE_ERROR();
+}
+
+int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
+                        int *input_size) {
+  // Given a graph, we want to fetch the nodes that have been marked as part of
+  // a subgraph.
+  API_BEGIN();
+  nnvm::Symbol *s = static_cast<nnvm::Symbol *>(sym);
+  size_t max_input_size = *input_size;
+  std::string subg_attr = "__subgraph_name__";
+  auto out_node = s->outputs[0].node;
+  auto it = out_node->attrs.dict.find(subg_attr);
+  if (it != out_node->attrs.dict.end()) {
+    std::string subg_name = it->second;
+    std::vector<nnvm::NodeEntry *> input_entries;
+    DFSVisit(s->outputs, [subg_attr, subg_name, &input_entries]
+             (nnvm::NodePtr n) {
+      // If the node itself isn't in the subgraph, we ignore it.
+      auto it = n->attrs.dict.find(subg_attr);
+      if (it == n->attrs.dict.end() || it->second != subg_name)
+        return;
+
+      // We search for nodes whose node entries aren't in the subgraph.
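+      // For example, if a node marked with __subgraph_name__ = "foreach" reads
+      // the output of a node without that attribute, the corresponding input
+      // entry crosses the subgraph boundary and is collected here so that
+      // CutGraph can later replace it with a new variable node.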
+ for (size_t j = 0; j < n->inputs.size(); j++) { + auto in_node = n->inputs[j].node; + auto it = in_node->attrs.dict.find(subg_attr); + if (it == in_node->attrs.dict.end() || it->second != subg_name) + input_entries.push_back(&n->inputs[j]); + } + }); + + std::vector orig_entries; + std::vector new_var_names; + CutGraph(input_entries, subg_name + "_var", false, &orig_entries, &new_var_names); + + std::vector input_syms(orig_entries.size()); + for (size_t i = 0; i < input_syms.size(); i++) { + input_syms[i] = new nnvm::Symbol(); + input_syms[i]->outputs.push_back(orig_entries[i]); + } + CHECK(input_syms.size() <= max_input_size); + *input_size = input_syms.size(); + memcpy(input_symbols, input_syms.data(), sizeof(*input_symbols) * input_syms.size()); + } else { + *input_size = 0; + } + + API_END_HANDLE_ERROR(); +} + int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 697e4869a049..b90aa83099ae 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -126,6 +126,10 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { PostFCompute(is_gpu); } + bool HasSubgraph() const override { + return !attrs_.subgraphs.empty(); + } + ExecType exec_type() const override { return exec_type_; } @@ -134,15 +138,17 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { return state_.get_var(); } - explicit StatefulComputeExecutor(const OpStatePtr& state, + explicit StatefulComputeExecutor(const NodeAttrs& attrs, + const OpStatePtr& state, const FStatefulCompute& fcompute, ExecType exec_type, const std::vector &mutate_idx) - : StorageFallbackOpExecutor(mutate_idx), + : StorageFallbackOpExecutor(mutate_idx), attrs_(attrs), state_(state), fcompute_(fcompute), exec_type_(exec_type) {} private: friend Graph AttachOpExecs(Graph g); + NodeAttrs attrs_; OpStatePtr state_; FStatefulCompute fcompute_; ExecType exec_type_; @@ -160,6 +166,10 @@ class StatefulComputeExExecutor : public OpExecutor { fcompute_(state_, op_ctx, in_array, req, out_array); } + bool HasSubgraph() const override { + return !attrs_.subgraphs.empty(); + } + void Setup() override {} ExecType exec_type() const override { @@ -170,13 +180,14 @@ class StatefulComputeExExecutor : public OpExecutor { return state_.get_var(); } - explicit StatefulComputeExExecutor(const OpStatePtr& state, + explicit StatefulComputeExExecutor(const NodeAttrs& attrs, const OpStatePtr& state, const FStatefulComputeEx& fcompute, ExecType exec_type) - : state_(state), fcompute_(fcompute), exec_type_(exec_type) {} + : attrs_(attrs), state_(state), fcompute_(fcompute), exec_type_(exec_type) {} private: friend Graph AttachOpExecs(Graph g); + NodeAttrs attrs_; OpStatePtr state_; FStatefulComputeEx fcompute_; ExecType exec_type_; @@ -201,6 +212,10 @@ class FComputeExecutor : public StorageFallbackOpExecutor { return exec_type_; } + bool HasSubgraph() const override { + return !attrs_.subgraphs.empty(); + } + explicit FComputeExecutor(const NodeAttrs& attrs, FCompute fcompute, ExecType exec_type, const std::vector &mutate_idx) : StorageFallbackOpExecutor(mutate_idx), @@ -226,6 +241,10 @@ class FComputeExExecutor : public OpExecutor { void Setup() override {} + bool HasSubgraph() const override { + return !attrs_.subgraphs.empty(); + } + ExecType exec_type() const override { return exec_type_; } @@ -289,15 +308,17 @@ Graph AttachOpExecs(Graph g) { op, 
"FStatefulComputeEx", vctx[i]); // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared(state, fcompute_ex, exec_type); + ret[i] = std::make_shared(inode.source->attrs, state, + fcompute_ex, exec_type); } else { FStatefulCompute fcompute = common::GetFCompute( op, "FStatefulCompute", vctx[i]); CHECK(fcompute != nullptr) << "One of FStatefulCompute and FStatefulComputeEx must be registered " << "for stateful operator " << op->name; - ret[i] = std::make_shared(state, fcompute, - exec_type, mutate_index); + ret[i] = std::make_shared(inode.source->attrs, state, + fcompute, exec_type, + mutate_index); } } else if (is_layer_backward.get(op, false)) { CHECK_GE(inode.control_deps.size(), 1); @@ -308,7 +329,7 @@ Graph AttachOpExecs(Graph g) { op, "FStatefulComputeEx", vctx[i]); // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared( + ret[i] = std::make_shared(inode.source->attrs, dynamic_cast(ret[fwd_id].get())->state_, fcompute_ex, exec_type); } else { @@ -317,7 +338,7 @@ Graph AttachOpExecs(Graph g) { CHECK(fcompute != nullptr) << "One of FStatefulCompute and FStatefulComputeEx must be registered " << "for stateful operator " << op->name; - ret[i] = std::make_shared( + ret[i] = std::make_shared(inode.source->attrs, dynamic_cast(ret[fwd_id].get())->state_, fcompute, exec_type, mutate_index); } diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 99b1b162eaee..f49fcf61db21 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -64,6 +64,7 @@ class OpExecutor { OpContext op_ctx; /*! \brief virtual destructor */ virtual ~OpExecutor() {} + virtual bool HasSubgraph() const = 0; /*! * \brief Setup the executor for given NDArray member * this can be called multiple times if NDArray changed during reshape. diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 7a15f6c931c7..ca06a12a5a0e 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -39,6 +39,7 @@ namespace exec { GraphExecutor::GraphExecutor() { log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false); + need_grad_ = false; } GraphExecutor::~GraphExecutor() { @@ -257,11 +258,11 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, nnvm::Graph g; g.outputs = symbol.outputs; - bool need_grad = false; + need_grad_ = false; for (OpReqType req : grad_req_types) { - if (req != kNullOp) need_grad = true; + if (req != kNullOp) need_grad_ = true; } - if (!need_grad) return g; + if (!need_grad_) return g; for (size_t i = 0; i < g.outputs.size(); ++i) { NodeEntry ngrad{nnvm::Node::Create(), 0, 0}; head_grad_entry_.emplace_back(AttrHint(ngrad, g.outputs[i])); @@ -1378,7 +1379,11 @@ void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) { // check if the segment relies on external input, or exceeds maxinum number of node, // or requires async ops if (node->is_variable() || nid - topo_start > num_nodes_threshold || - op_node.exec->exec_type() != ExecType::kSync) { + op_node.exec->exec_type() != ExecType::kSync || + // If the node has a subgraph, we shouldn't add it to the segment. + // We'll execute the node separately from other nodes. + // CreateCachedSegOpr creates a segment excluding nodes with subgraphs. 
+ op_node.exec->HasSubgraph()) { // create a new segment for the previous nodes if the current one cannot be bulked cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); topo_start = nid + 1; @@ -1403,7 +1408,11 @@ void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) { continue; } if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold || - op_node.exec->exec_type() != ExecType::kSync) { + op_node.exec->exec_type() != ExecType::kSync || + // If the node has a subgraph, we shouldn't add it to the segment. + // We'll execute the node separately from other nodes. + // CreateCachedSegOpr creates a segment excluding nodes with subgraphs. + op_node.exec->HasSubgraph()) { cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); topo_start = nid + 1; } else { @@ -1437,7 +1446,11 @@ void GraphExecutor::BulkInferenceOpSegs() { // Variables do not need to be segmented at inference time. if (node->is_variable()) continue; - if (op_node.exec->exec_type() != ExecType::kSync) { + if (op_node.exec->exec_type() != ExecType::kSync || + // If the node has a subgraph, we shouldn't add it to the segment. + // We'll execute the node separately from other nodes. + // CreateCachedSegOpr creates a segment excluding nodes with subgraphs. + op_node.exec->HasSubgraph()) { cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); topo_start = nid + 1; } @@ -1480,6 +1493,7 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; opnode.exec->op_ctx.is_train = is_train; + opnode.exec->op_ctx.need_grad = need_grad_; } // Push Ops @@ -1498,11 +1512,15 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { OpNode& opnode = op_nodes_[nid]; if (op_nodes_[nid].skip_exec_node) continue; opnode.exec->op_ctx.is_train = is_train; + opnode.exec->op_ctx.need_grad = need_grad_; if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) { CHECK_EQ(inode.inputs.size(), 1U); CHECK_EQ(opnode.exec->in_array.size(), 1U); CHECK_EQ(opnode.exec->out_array.size(), 1U); CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0])); + } else if (opnode.exec->HasSubgraph()) { + // If the node contains a subgraph, we can't execute it in the engine. + opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false); } else if (opnode.cached_opr != nullptr) { bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning; Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling); @@ -1537,6 +1555,9 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, OpNode& op_node = op_nodes_[nid]; if (op_node.skip_exec_node) continue; if (inode.source->is_variable()) continue; + // We shouldn't add control flow operators to a segment. + // We can't execute these operators in the engine. + if (op_node.exec->HasSubgraph()) return ret; if (op_node.exec->exec_type() != ExecType::kSync) { return ret; } diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index bcde41d508eb..fa2a156d3d76 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -203,6 +203,8 @@ class GraphExecutor : public Executor { // perform bulking and segmentation on a training graph void BulkTrainingOpSegs(size_t total_num_nodes); + // indicate whether there is a backward graph for gradients. 
+ bool need_grad_; // internal graph nnvm::Graph graph_; // operator node diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index d7bb37b7cfef..1135c0d2d416 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -359,6 +359,7 @@ inline void PushFCompute(const FCompute& fn, static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); bool is_train = Imperative::Get()->is_training(); + bool need_grad = Imperative::Get()->is_recording(); ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync; CHECK(exec_type == ExecType::kSync); std::vector inputs, outputs; @@ -379,7 +380,7 @@ inline void PushFCompute(const FCompute& fn, &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst, &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx); // setup context - OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested}; + OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; bool is_gpu = ctx.dev_mask() == gpu::kDevMask; // pre-fcompute fallback, cast to default storage type CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx, is_gpu); @@ -406,11 +407,12 @@ inline void PushFComputeEx(const FComputeEx& fn, static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); bool is_train = Imperative::Get()->is_training(); + bool need_grad = Imperative::Get()->is_recording(); ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync; std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); const auto& run = [=](RunContext rctx) { - OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested}; + OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; #if MXNET_USE_MKLDNN == 1 InvalidateOutputs(outputs, req); #endif @@ -445,6 +447,7 @@ inline void PushOperator(const OpStatePtr& state, static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); bool is_train = Imperative::Get()->is_training(); + bool need_grad = Imperative::Get()->is_recording(); ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync; std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); @@ -456,17 +459,23 @@ inline void PushOperator(const OpStatePtr& state, if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) { const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { - OpContext opctx{is_train, rctx, on_complete, requested}; + OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; #if MXNET_USE_MKLDNN == 1 InvalidateOutputs(outputs, req); #endif fcompute_ex(state, opctx, inputs, req, outputs); - if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) { + if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync + && rctx.get_stream()) { rctx.get_stream()->Wait(); } }; - if (exec_type == ExecType::kSync) { + // For operators with subgraphs, we need to invoke them in the main thread + // instead of the threaded engine. 
+ if (!attrs.subgraphs.empty()) { + RunContext rctx{ctx, nullptr}; + run(rctx, engine::CallbackOnComplete()); + } else if (exec_type == ExecType::kSync) { Engine::Get()->PushSync( [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); }, ctx, read_vars, write_vars, FnProperty::kNormal, 0, @@ -483,7 +492,7 @@ inline void PushOperator(const OpStatePtr& state, << "for stateful operator " << op->name; const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { - OpContext opctx{is_train, rctx, on_complete, requested}; + OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; std::vector input_blobs, output_blobs; // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays @@ -505,12 +514,16 @@ inline void PushOperator(const OpStatePtr& state, fcompute(state, opctx, input_blobs, tmp_req, output_blobs); // post-fcompute fallback, cast to original storage type, if necessary CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu); - if (is_gpu && exec_type == ExecType::kSync) { + if (is_gpu && exec_type == ExecType::kSync + && rctx.get_stream()) { rctx.get_stream()->Wait(); } }; - if (exec_type == ExecType::kSync) { + if (!attrs.subgraphs.empty()) { + RunContext rctx{ctx, nullptr}; + run(rctx, engine::CallbackOnComplete()); + } else if (exec_type == ExecType::kSync) { Engine::Get()->PushSync( [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index d87e8bc95ea5..764711f020ff 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -200,6 +200,7 @@ NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const { ret.ptr_->delay_alloc = false; ret.ptr_->static_data = true; ret.byte_offset_ = byte_offset_; + ret.reuse_ = false; return ret; } } @@ -217,6 +218,7 @@ NDArray NDArray::Reshape(const TShape &shape) const { // Otherwise, reshape only works on the default layout. CHECK_EQ(storage_type(), kDefaultStorage); ret.shape_ = shape; + ret.reuse_ = false; return ret; } @@ -249,6 +251,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const { MSHADOW_TYPE_SWITCH(ret.dtype(), DType, { ret.byte_offset_ += begin * length * sizeof(DType); }); + ret.reuse_ = false; ret.shape_[0] = end - begin; return ret; } @@ -555,6 +558,7 @@ NDArray NDArray::Reorder2Default() const { // reshape as needed ret.shape_ = shape_; ret.byte_offset_ = byte_offset_; + ret.reuse_ = false; return ret; } @@ -584,39 +588,39 @@ void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc) const mkldnn::memory *NDArray::GetMKLDNNData() const { CHECK(storage_type() == kDefaultStorage); + bool is_view = IsView(); if (IsMKLDNNData()) { // If this array uses MKLDNN layout, we have to make sure it's not a view. // Otherwise, we'll have to change the layout inside the array. - CHECK(!IsView()); + CHECK(!is_view); MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); // If this array uses MKLDNN format, we should return now. Otherwise, // SetMKLMem may mess up mkl_mem_. return ptr_->mkl_mem_->GetRaw(); - } - ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_); - MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); - if (IsView()) { - mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc(); - // Sliced array must use the default layout. 
- CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format); - void *off_addr = static_cast(ptr_->mkl_mem_->GetDataHandle()) - + byte_offset_; - + } else if (is_view) { + // If this is a view, we can't create a MKLDNN memory for the chunk + // because we don't have the complete data type and shape information for + // the chunk. + void *off_addr = static_cast(ptr_->shandle.dptr) + byte_offset_; // Create the primitive desc for the new mkldnn memory. mkldnn::memory::dims dims(shape().ndim()); for (size_t i = 0; i < dims.size(); i++) dims[i] = shape()[i]; mkldnn::memory::format cpp_format = static_cast( GetDefaultFormat(shape().ndim())); - mkldnn::memory::data_type cpp_type = static_cast( - pd.desc().data.data_type); + mkldnn::memory::data_type cpp_type = get_mkldnn_type(dtype_); mkldnn::memory::desc data_md(dims, cpp_type, cpp_format); - mkldnn::memory::primitive_desc new_pd(data_md, pd.get_engine()); + mkldnn::memory::primitive_desc new_pd(data_md, + CpuEngine::Get()->get_engine()); std::shared_ptr ret(new mkldnn::memory(new_pd, off_addr)); MKLDNNStream::Get()->RegisterMem(ret); return ret.get(); } else { + // If this isn't a view, we can create a MKLDNN memory and store it in the + // chunk. + ptr_->SetMKLMem(shape_, dtype_); + MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); return ptr_->mkl_mem_->GetRaw(); } } @@ -637,10 +641,9 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) { MKLDNNStream *stream = MKLDNNStream::Get(); // If this array uses MKLDNN layout, we have to make sure it's not a view. // Otherwise, we'll have to change the layout inside the array. - if (IsMKLDNNData()) - CHECK(!IsView()); - ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, - dtype_); + + CHECK(!IsView()); + ptr_->SetMKLMem(shape_, dtype_); stream->RegisterMem(ptr_->mkl_mem_->GetMem()); mkldnn::memory::desc from_desc = mem.get_primitive_desc().desc(); mkldnn::memory::desc this_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc(); @@ -713,9 +716,6 @@ mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc p mkldnn_memory_format_t format); mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &desc) { - // This array shouldn't be a view. - CHECK(!IsView()); - if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) { LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; return nullptr; @@ -726,10 +726,26 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc & mkldnn_memory_format_t def_format = GetDefaultFormat(_desc.desc()); // If the required format is a default format, we don't need to worry about the shape. // If the shape isn't the same, it actually implicitly reshapes data. - if (required_format == def_format) { + if (required_format == def_format && !IsView()) { ptr_->SetMKLMem(shape_, dtype_); MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc); + } else if (required_format == def_format) { + ptr_->CheckAndAlloc(); + CHECK(ptr_->shandle.dptr); + // When this is a view and a user wants the default layout, we can simply + // create a new mkldnn memory that points to the right memory. + std::shared_ptr mem(new mkldnn::memory( + desc, ptr_->shandle.dptr + byte_offset_)); + MKLDNNStream::Get()->RegisterMem(mem); + return mem.get(); + } else if (IsView()) { + // If this is a view and a user wants to write data to it with special + // a MKLDNN format, we should reorder the data in the array and return NULL. 
+ // In this way, the user will create a new NDArray for the special format + // and copy data back. + ptr_->Reorder2Default(); + return nullptr; } if (ptr_->mkl_mem_) @@ -1160,7 +1176,8 @@ void CopyFromToImpl(const NDArray& from, const NDArray& to, const Context to_ctx = to.ctx(); bool is_train = Imperative::Get()->is_training(); - OpContext opctx{is_train, + OpContext opctx{Imperative::Get()->is_recording(), + is_train, rctx, engine::CallbackOnComplete(), requested}; diff --git a/src/nnvm/graph_editor.cc b/src/nnvm/graph_editor.cc new file mode 100644 index 000000000000..98c99e2425df --- /dev/null +++ b/src/nnvm/graph_editor.cc @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file graph_editor.cc + * The functions in this file edit an NNVM graph. Potentially, + * these functions should be moved to NNVM in the future. + */ + +#include +#include +#include + +namespace nnvm { +NodePtr CreateVariableNode(const std::string& name); +} + +namespace mxnet { + +/* + * Given a computation graph, this function finds the input nodes of the graph + * and create symbols for the input nodes. It returns the input symbols. + */ +std::vector GetInputSymbols(const nnvm::Symbol &sym) { + nnvm::Graph g; + std::vector input_syms; + g.outputs = sym.outputs; + const nnvm::IndexedGraph& idx = g.indexed_graph(); + // Go through all nodes and return the ones representing variables. + for (size_t i = 0; i < idx.num_nodes(); i++) { + const nnvm::Node &n = *idx[i].source; + for (const nnvm::NodeEntry &e : n.inputs) { + auto p = e.node; + if (p->is_variable()) { + nnvm::Symbol *s = new nnvm::Symbol(); + s->outputs.push_back(e); + input_syms.push_back(s); + } + } + } + return input_syms; +} + +/* + * Given a computation graph and a set of input node entries, this function cuts + * the node entries and creates new variable nodes as the input nodes of the + * subgraph. It returns the nodes that connect to the subgraph directly and + * the names of the new variable nodes. + */ +bool CutGraph(const std::vector &input_entries, + const std::string &in_name_prefix, bool skip_var, + std::vector *orig_entries, + std::vector *new_var_names) { + orig_entries->reserve(input_entries.size()); + for (size_t i = 0; i < input_entries.size(); i++) { + nnvm::NodeEntry *e = input_entries[i]; + // If the node is a variable itself, we may want to skip the node. 
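+    // A variable entry is already an input of the graph, so cutting it would
+    // merely rename it. Each entry that is cut below is replaced with a fresh
+    // variable named in_name_prefix + i (e.g. "foreach_var0"), and the original
+    // entry is saved in orig_entries so the caller can reconnect it.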
+ if (e->node->is_variable() && skip_var) + continue; + + orig_entries->push_back(*e); + new_var_names->push_back(in_name_prefix + std::to_string(i)); + nnvm::NodePtr n = nnvm::CreateVariableNode(new_var_names->back()); + *e = nnvm::NodeEntry{n, 0, 0}; + } + return true; +} + +} diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc new file mode 100644 index 000000000000..c42aca0944d9 --- /dev/null +++ b/src/operator/control_flow.cc @@ -0,0 +1,419 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" +#include "./elemwise_op_common.h" +#include "../imperative/imperative_utils.h" +#include "./subgraph_op_common.h" + +namespace mxnet { +namespace op { + +struct ForeachParam : public dmlc::Parameter { + int num_args; + int dim; + int num_outputs; + int num_out_data; + nnvm::Tuple in_state_locs; + nnvm::Tuple in_data_locs; + DMLC_DECLARE_PARAMETER(ForeachParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) + .describe("Number of inputs."); + DMLC_DECLARE_FIELD(dim).set_default(1) + .describe("the dimension of the input array to iterate."); + DMLC_DECLARE_FIELD(num_outputs) + .describe("The number of outputs of the subgraph."); + DMLC_DECLARE_FIELD(num_out_data) + .describe("The number of output data of the subgraph."); + DMLC_DECLARE_FIELD(in_state_locs) + .describe("The locations of loop states among the inputs."); + DMLC_DECLARE_FIELD(in_data_locs) + .describe("The locations of input data among the inputs."); + } +}; // struct ForeachParam + +DMLC_REGISTER_PARAMETER(ForeachParam); + +class ForeachState: public LoopState { + public: + ForeachParam params; + + ForeachState(const Symbol &g, const ForeachParam ¶ms) : LoopState(g) { + this->params = params; + } +}; + +static void ForeachComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + ForeachState &state = state_ptr.get_state(); + const ForeachParam& params = state.params; + size_t iter_dim = 0; + CHECK_EQ(outputs.size(), (size_t) params.num_outputs); + CHECK_GT(params.in_data_locs.ndim(), 0); + size_t loc0 = params.in_data_locs[0]; + size_t len = inputs[loc0].shape()[iter_dim]; + for (size_t i = 1; i < params.in_data_locs.ndim(); i++) { + size_t loc = params.in_data_locs[i]; + CHECK_EQ(inputs[loc].shape()[iter_dim], len); + } + for (size_t i = 0; i < (size_t) params.num_out_data; i++) + CHECK_EQ(len, outputs[i].shape()[iter_dim]); + for (const auto &arr : outputs) + CHECK_EQ(arr.storage_type(), kDefaultStorage) + << "The for operator doesn't support the sparse format"; + + // Initialize the outputs of the subgraph is a little trickier. 
+ // The states from the previous iteration are used as the inputs of the next + // iteration, so I have to maintain two arrays, so the inputs and outputs + // of the subgraph share the same memory. + std::vector subg_outputs1(outputs.size()); + std::vector subg_outputs2(outputs.size()); + std::vector *subg_outputs[2]{&subg_outputs1, &subg_outputs2}; + // If the length is an odd number, the last iteration will use the first set + // of outputs. In this way, we don't need to copy the results from the + // subgraph to the final outputs of the loop. + if (len % 2 == 1) { + for (size_t i = 1; i < subg_outputs1.size(); i++) { + subg_outputs1[i] = outputs[i]; + subg_outputs2[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, + outputs[i].dtype()); + } + } else { + // Otherwise, we'll use the second set of outputs. + for (size_t i = 1; i < subg_outputs1.size(); i++) { + subg_outputs1[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, + outputs[i].dtype()); + subg_outputs2[i] = outputs[i]; + } + } + + // Initialize the inputs for the subgraph. + // In each iteration, we need to update the subgraph inputs for input data + // and the loop states. This initialization helps to get the read-only + // arrays in the loop. + std::vector subg_inputs(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + // These are the initial states. + subg_inputs[i] = inputs[i]; + } + + // Here we iterate over the first dimension of the first input array. + for (size_t i = 0; i < len; i++) { + // Initialize outputs for the subgraph. + std::vector *subg_out_curr = subg_outputs[i % 2]; + std::vector *subg_out_prev = subg_outputs[(i + 1) % 2]; + for (int j = 0; j < params.num_out_data; j++) + (*subg_out_curr)[j] = outputs[j].At(i); + // When recording for backward computation, we should make sure + // that output arrays are actually different in each iteration. + if (ctx.need_grad && i < len - 1) { + for (size_t j = params.num_out_data; j < subg_out_curr->size(); j++) + (*subg_out_curr)[j] = NDArray(outputs[j].shape(), outputs[j].ctx(), + true, outputs[j].dtype()); + } else if (ctx.need_grad && i == len - 1) { + // For the last iteration, we need to write data to the output array + // directly. + for (size_t j = params.num_out_data; j < subg_out_curr->size(); j++) + (*subg_out_curr)[j] = outputs[j]; + } + + // Initialize inputs for the subgraph. + // Get a slice from the input data arrays. + for (size_t j = 0; j < params.in_data_locs.ndim(); j++) { + size_t loc = params.in_data_locs[j]; + subg_inputs[loc] = inputs[loc].At(i); + } + // For the rest of the iterations, the rest of the arguments are the outputs + // from the previous iteration. + if (i > 0) { + for (size_t j = params.num_out_data; j < subg_out_prev->size(); j++) { + size_t idx = j - params.num_out_data; + CHECK_LT(params.in_state_locs[idx], subg_inputs.size()); + subg_inputs[params.in_state_locs[idx]] = (*subg_out_prev)[j]; + } + } + + state.Forward(subg_inputs, req, *subg_out_curr, ctx.need_grad); + // We need to wait for the iteration to complete before executing + // the next one or return from the loop. In this way, we can reuse + // the memory in the subgraph. 
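+    // (This also serializes the iterations: with the two output sets reused in
+    // a ping-pong fashion, iteration i+1 must not start reading the state
+    // arrays while iteration i is still writing them.)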
+ for (size_t j = 0; j < subg_out_curr->size(); j++) { + (*subg_out_curr)[j].WaitToRead(); + } + } +} + +static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + ForeachState &state = state_ptr.get_state(); + const ForeachParam& params = state.params; + CHECK_EQ(outputs.size(), (size_t) params.num_args - 1); + CHECK_GT(params.in_data_locs.ndim(), 0); + for (const auto &arr : outputs) + CHECK_EQ(arr.storage_type(), kDefaultStorage) + << "The for operator doesn't support the sparse format"; + size_t iter_dim = 0; + std::unordered_set in_data_locs(params.in_data_locs.begin(), + params.in_data_locs.end()); + std::unordered_set in_state_locs(params.in_state_locs.begin(), + params.in_state_locs.end()); + // The inputs contain out gradients, inputs and outputs. + int len = inputs[0].shape()[iter_dim]; + size_t num_output_data = params.num_out_data; + + // In backward computation, we need to run iterations from backwards. + std::vector ograds(params.num_outputs); + std::vector igrads(outputs.size()); + for (size_t i = num_output_data; i < ograds.size(); i++) + ograds[i] = inputs[i]; + std::vector iter_req(req.size()); + for (auto r : req) + CHECK_NE(r, kWriteInplace); + for (int iter_num = len - 1; iter_num >= 0; iter_num--) { + for (int i = 0; i < params.num_out_data; i++) + ograds[i] = inputs[i].At(iter_num); + + // There are three types of arrays in igrads. + // * data gradients. + // * loop variable gradients. + // * read-only variable gradients. + // These are the input data gradients. + for (size_t i = 0; i < igrads.size(); i++) { + // data gradients. + if (in_data_locs.count(i)) { + igrads[i] = outputs[i].At(iter_num); + iter_req[i] = req[i]; + continue; + } + + bool in_state = in_state_locs.count(i); + if (iter_num != 0 && in_state) { + // For state gradients, we need to allocate new NDArrays + // because intermediate state gradients won't be returned to the users. + igrads[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), + true, outputs[i].dtype()); + } else { + igrads[i] = outputs[i]; + } + if (in_state) + // For the first iteration, we need to use the request provided by + // the user to write state gradients to the outputs. + iter_req[i] = iter_num != 0 ? kWriteTo : req[i]; + else + // For all read-only variable gradients, we need to use the request + // provided by the user in the last iteration and later on add gradients + // to the output arrays. + iter_req[i] = iter_num == len - 1 ? req[i]: kAddTo; + } + + state.Backward(iter_num, ograds, iter_req, igrads); + + // We need to wait for the iteration to complete before executing + // the next one or return from the loop. In this way, we can reuse + // the memory in the subgraph. + for (size_t i = 0; i < igrads.size(); i++) { + igrads[i].WaitToRead(); + } + + size_t num_states = ograds.size() - num_output_data; + for (size_t i = 0; i < num_states; i++) { + size_t loc = params.in_state_locs[i]; + CHECK_LT(loc, igrads.size()); + ograds[i + num_output_data] = igrads[loc]; + } + } + state.Cleanup(); +} + +static bool ForeachShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + const ForeachParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_shape->size(), (size_t) params.num_outputs); + nnvm::ShapeVector shape_inputs = *in_shape; + // foreach iterates over the first input NDArray over the first dimension. 
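+  // For example, a data input of shape (len, d1, d2) becomes a subgraph input
+  // of shape (d1, d2), and a subgraph output data entry of shape (e1,) is
+  // reported back below as a foreach output of shape (len, e1).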
+ size_t loc0 = params.in_data_locs[0]; + size_t len = in_shape->at(loc0)[0]; + for (size_t i = 0; i < params.in_data_locs.ndim(); i++) { + size_t loc = params.in_data_locs[i]; + CHECK_EQ(len, in_shape->at(loc)[0]); + shape_inputs[loc] = TShape(in_shape->at(loc).begin() + 1, in_shape->at(loc).end()); + } + CHECK_EQ(attrs.subgraphs.size(), 1U); + nnvm::Graph g; + g.outputs = attrs.subgraphs[0]->outputs; + const auto& idx = g.indexed_graph(); + CHECK_EQ(idx.input_nodes().size(), in_shape->size()); + CHECK_EQ(idx.outputs().size(), out_shape->size()); + imperative::CheckAndInferShape(&g, std::move(shape_inputs), true); + + const auto& shapes = g.GetAttr("shape"); + // Inferring the shape in the subgraph may infer the shape of the inputs. + // We need to copy the inferred input shapes back. + const auto &input_nids = idx.input_nodes(); + CHECK_EQ(input_nids.size(), in_shape->size()); + for (size_t i = 0; i < in_shape->size(); i++) { + auto eid = idx.entry_id(input_nids[i], 0); + // If the input shape is none, we should update them. + if ((*in_shape)[i].ndim() == 0 || (*in_shape)[i].Size() == 0) + SHAPE_ASSIGN_CHECK(*in_shape, i, shapes[eid]); + } + + // For the shape of output data. + for (int i = 0; i < params.num_out_data; i++) { + uint32_t eid = idx.entry_id(g.outputs[i]); + const auto& g_out_shape = shapes[eid]; + auto out = TShape(g_out_shape.ndim() + 1); + out[0] = len; + for (size_t i = 1; i < out.ndim(); i++) + out[i] = g_out_shape[i - 1]; + SHAPE_ASSIGN_CHECK(*out_shape, i, out); + } + + // For the remaining shapes. + for (size_t i = params.num_out_data; i < g.outputs.size(); i++) { + uint32_t eid = idx.entry_id(g.outputs[i]); + SHAPE_ASSIGN_CHECK(*out_shape, i, shapes[eid]); + } + size_t num_states = g.outputs.size() - params.num_out_data; + for (size_t i = 0; i < num_states; i++) { + size_t loc = params.in_state_locs[i]; + CHECK((*out_shape)[i + params.num_out_data] == (*in_shape)[loc]); + } + return true; +} + +static bool ForeachType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, std::vector *out_type) { + const ForeachParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_type->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 1U); + return InferSubgraphDataType(*attrs.subgraphs[0], in_type, out_type); +} + +static bool ForeachStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const ForeachParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 1U); + return InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, + dispatch_mode, in_attrs, out_attrs); +} + +static bool BackwardForeachStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const ForeachParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_attrs->size(), (size_t) params.num_args - 1); + CHECK_EQ(attrs.subgraphs.size(), 1U); + return InferSubgraphBackwardStorage(*attrs.subgraphs[0], dev_mask, + dispatch_mode, in_attrs, out_attrs); +} + +static OpStatePtr CreateForeachState(const NodeAttrs& attrs, + Context ctx, + const std::vector& ishape, + const std::vector& itype) { + const ForeachParam& params = nnvm::get(attrs.parsed); + return OpStatePtr::Create(*attrs.subgraphs[0], params); +} + +static std::vector +ForeachGradient(const nnvm::NodePtr& n, const std::vector& ograds) { + ElemwiseGradUseInOut 
fgrad{"_backward_foreach"}; + std::vector entries = fgrad(n, ograds); + entries[0].node->attrs.subgraphs = n->attrs.subgraphs; + return entries; +} + +NNVM_REGISTER_OP(_foreach) +.MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation") +.set_attr_parser(ParamParser) +.set_attr("FInferStorageType", ForeachStorageType) +.set_num_inputs([](const NodeAttrs& attrs) { + const ForeachParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + const ForeachParam& params = nnvm::get(attrs.parsed); + return params.num_outputs; +}) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const ForeachParam& params = nnvm::get(attrs.parsed); + std::vector names; + names.push_back("fn"); + for (int i = 0; i < params.num_args - 1; i++) + names.push_back("data" + std::to_string(i)); + return names; +}) +.set_attr("FInputGraph", + [](const NodeAttrs& attrs) { + return std::vector{0}; +}) +.set_attr("FGradient", ForeachGradient) +.set_attr("FCreateOpState", CreateForeachState) +.set_attr("FInferShape", ForeachShape) +.set_attr("FInferType", ForeachType) +.set_attr("FStatefulComputeEx", ForeachComputeExCPU) +// Foreach operator works like an executor. Its code will always run on CPU. +// So the same code can be registered for both CPU and GPU. +.set_attr("FStatefulComputeEx", ForeachComputeExCPU) +.set_attr("key_var_num_args", "num_args") +.add_argument("fn", "Symbol", "Input graph.") +.add_argument("data", "NDArray-or-Symbol[]", + "The input arrays that include data arrays and states.") +.add_arguments(ForeachParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_foreach) +.set_num_inputs([](const NodeAttrs& attrs){ + const ForeachParam& params = nnvm::get(attrs.parsed); + return params.num_outputs * 2 + params.num_args - 1; + }) +.set_num_outputs([](const NodeAttrs& attrs){ + const ForeachParam& params = nnvm::get(attrs.parsed); + return params.num_args - 1; + }) +.set_attr("FInferStorageType", BackwardForeachStorageType) +.set_attr_parser(ParamParser) +.set_attr("TIsLayerOpBackward", true) +.set_attr("TIsBackward", true) +.set_attr("FStatefulComputeEx", ForeachGradComputeExCPU) +.set_attr("FStatefulComputeEx", ForeachGradComputeExCPU); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc new file mode 100644 index 000000000000..fa22898c13d4 --- /dev/null +++ b/src/operator/subgraph_op_common.cc @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "./subgraph_op_common.h" +#include "./operator_common.h" +#include "../imperative/imperative_utils.h" + +namespace mxnet { +namespace op { + +bool InferSubgraphDataType(const nnvm::Symbol &subgraph, + std::vector *in_type, + std::vector *out_type) { + nnvm::DTypeVector dtype_inputs = *in_type; + nnvm::Graph g; + g.outputs = subgraph.outputs; + const auto& idx = g.indexed_graph(); + CHECK_EQ(idx.input_nodes().size(), in_type->size()); + CHECK_EQ(idx.outputs().size(), out_type->size()); + imperative::CheckAndInferType(&g, std::move(dtype_inputs), true); + + const auto &dtypes = g.GetAttr("dtype"); + + // Inferring the data type in the subgraph may infer the data type of the inputs. + // We need to copy the inferred input data types back. + const auto &input_nids = idx.input_nodes(); + CHECK_EQ(input_nids.size(), in_type->size()); + for (size_t i = 0; i < in_type->size(); i++) { + auto eid = idx.entry_id(input_nids[i], 0); + TYPE_ASSIGN_CHECK(*in_type, i, dtypes[eid]); + } + + for (size_t i = 0; i < g.outputs.size(); i++) + TYPE_ASSIGN_CHECK(*out_type, i, dtypes[idx.entry_id(g.outputs[i])]); + return true; +} + +bool InferSubgraphStorage(const nnvm::Symbol &subgraph, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + nnvm::Graph g; + g.outputs = subgraph.outputs; + const auto& idx = g.indexed_graph(); + CHECK_EQ(idx.input_nodes().size(), in_attrs->size()); + CHECK_EQ(idx.outputs().size(), out_attrs->size()); + exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask); + StorageTypeVector storage_type_inputs = *in_attrs; + imperative::CheckAndInferStorageType(&g, std::move(dev_masks), + std::move(storage_type_inputs), true); + + const auto& stypes = g.GetAttr("storage_type"); + + // Inferring the storage in the subgraph may infer the storage of the inputs. + // We need to copy the inferred input storage back. 
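+  // (For example, an input whose storage type the caller left undefined may
+  // be constrained to kDefaultStorage by the operators inside the subgraph.)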
+bool InferSubgraphStorage(const nnvm::Symbol &subgraph,
+                          const int dev_mask,
+                          DispatchMode* dispatch_mode,
+                          std::vector<int> *in_attrs,
+                          std::vector<int> *out_attrs) {
+  nnvm::Graph g;
+  g.outputs = subgraph.outputs;
+  const auto& idx = g.indexed_graph();
+  CHECK_EQ(idx.input_nodes().size(), in_attrs->size());
+  CHECK_EQ(idx.outputs().size(), out_attrs->size());
+  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+  StorageTypeVector storage_type_inputs = *in_attrs;
+  imperative::CheckAndInferStorageType(&g, std::move(dev_masks),
+                                       std::move(storage_type_inputs), true);
+
+  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+
+  // Inferring the storage in the subgraph may infer the storage of the inputs.
+  // We need to copy the inferred input storage back.
+  const auto &input_nids = idx.input_nodes();
+  CHECK_EQ(input_nids.size(), in_attrs->size());
+  for (size_t i = 0; i < in_attrs->size(); i++) {
+    auto eid = idx.entry_id(input_nids[i], 0);
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i, stypes[eid]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  auto &outputs = idx.outputs();
+  CHECK(outputs.size() == out_attrs->size());
+  for (size_t i = 0; i < out_attrs->size(); i++)
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, stypes[idx.entry_id(outputs[i])]);
+  return true;
+}
+
+bool InferSubgraphBackwardStorage(const nnvm::Symbol &subgraph,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs) {
+  using namespace nnvm;
+  // construct backward graph
+  nnvm::Graph grad_graph;
+  nnvm::Graph fwd_graph;
+  std::vector<const nnvm::Node*> potential_nodes;
+  {
+    fwd_graph.outputs = subgraph.outputs;
+    std::vector<NodeEntry> ograd_entries;
+    ograd_entries.reserve(fwd_graph.outputs.size());
+    for (size_t i = 0; i < fwd_graph.outputs.size(); ++i) {
+      ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0});
+    }
+
+    std::vector<NodeEntry> xs;
+    std::vector<NodePtr> args = subgraph.ListInputs(nnvm::Symbol::kReadOnlyArgs);
+    xs.reserve(args.size());
+    for (const auto& i : args)
+      xs.emplace_back(NodeEntry{i, 0, 0});
+    CHECK_GT(xs.size(), 0)
+        << "There are no inputs in computation graph that require gradients.";
+
+    static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
+    grad_graph = pass::Gradient(
+        fwd_graph, fwd_graph.outputs, xs, ograd_entries,
+        exec::AggregateGradient, nullptr, nullptr,
+        zero_ops, "_copy");
+    potential_nodes.reserve(fwd_graph.outputs.size() + xs.size() + ograd_entries.size());
+    for (auto e : ograd_entries)
+      potential_nodes.push_back(e.node.get());
+    for (auto e : xs)
+      potential_nodes.push_back(e.node.get());
+    for (auto e : fwd_graph.outputs)
+      potential_nodes.push_back(e.node.get());
+  }
+
+  const auto& idx = grad_graph.indexed_graph();
+  auto input_nodes = idx.input_nodes();
+  StorageTypeVector storage_type_inputs(input_nodes.size());
+  for (size_t i = 0; i < input_nodes.size(); i++) {
+    auto node_id = input_nodes[i];
+    const nnvm::IndexedGraph::Node &n = idx[node_id];
+    auto it = std::find(potential_nodes.begin(), potential_nodes.end(), n.source);
+    CHECK(it != potential_nodes.end());
+    // Use a distinct name here so we don't shadow the graph index `idx` above.
+    size_t pos = it - potential_nodes.begin();
+    CHECK_LT(pos, in_attrs->size());
+    storage_type_inputs[i] = in_attrs->at(pos);
+  }
+  CHECK_EQ(idx.outputs().size(), out_attrs->size());
+  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+  imperative::CheckAndInferStorageType(&grad_graph, std::move(dev_masks),
+                                       std::move(storage_type_inputs), true);
+
+  const auto& stypes = grad_graph.GetAttr<StorageTypeVector>("storage_type");
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  auto &outputs = idx.outputs();
+  CHECK(outputs.size() == out_attrs->size());
+  for (size_t i = 0; i < out_attrs->size(); i++)
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, stypes[idx.entry_id(outputs[i])]);
+  return true;
+}
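InferSubgraphBackwardStorage rebuilds the gradient graph and maps each of its inputs back to a position in in_attrs through potential_nodes, whose order is: the output gradients, then the subgraph's read-only arguments, then the forward outputs. A hypothetical helper (not an MXNet API) that spells out the same layout, matching the num_outputs * 2 + num_args - 1 input count registered on _backward_foreach:

    def backward_input_layout(num_outputs, num_args):
        """Illustrative only: the order in which _backward_foreach sees inputs."""
        ograds = ['ograd%d' % i for i in range(num_outputs)]      # head gradients
        fwd_inputs = ['in%d' % i for i in range(num_args - 1)]    # saved inputs
        fwd_outputs = ['out%d' % i for i in range(num_outputs)]   # saved outputs
        layout = ograds + fwd_inputs + fwd_outputs
        assert len(layout) == num_outputs * 2 + num_args - 1
        return layout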
+void LoopState::Forward(std::vector<NDArray> cinputs,
+                        const std::vector<OpReqType>& req,
+                        std::vector<NDArray> coutputs,
+                        bool is_recording) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  bool orig_is_record;
+  if (is_recording)
+    orig_is_record = Imperative::Get()->set_is_recording(true);
+  else
+    orig_is_record = Imperative::Get()->is_recording();
+
+  std::vector<NDArray*> inputs(cinputs.size());
+  std::vector<NDArray*> outputs(coutputs.size());
+  for (size_t i = 0; i < inputs.size(); i++)
+    inputs[i] = &cinputs[i];
+  for (size_t i = 0; i < outputs.size(); i++)
+    outputs[i] = &coutputs[i];
+
+  if (is_recording) {
+    all_inputs.push_back(cinputs);
+    std::vector<NDArray> gradients(cinputs.size());
+    std::vector<NDArray*> input_ptrs(cinputs.size());
+    std::vector<NDArray*> gradient_ptrs(cinputs.size());
+    std::vector<mx_uint> grad_reqs(cinputs.size());
+    for (size_t i = 0; i < gradients.size(); i++) {
+      gradients[i] = NDArray(cinputs[i].shape(), cinputs[i].ctx(),
+                             true, cinputs[i].dtype());
+      input_ptrs[i] = &cinputs[i];
+      gradient_ptrs[i] = &gradients[i];
+      grad_reqs[i] = kWriteTo;
+    }
+    Imperative::Get()->MarkVariables(input_ptrs, grad_reqs, gradient_ptrs);
+  }
+
+  std::vector<std::pair<std::string, std::string> > kwargs;
+  kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0"));
+  // Get input names.
+  const auto& idx = subgraph.indexed_graph();
+  std::vector<std::string> arg_names(idx.input_nodes().size());
+  for (size_t i = 0; i < idx.input_nodes().size(); ++i)
+    arg_names[i] = idx[idx.input_nodes()[i]].source->attrs.name;
+  // We don't have parameters for the cached op.
+  std::unordered_map<std::string, std::vector<NDArray> > params;
+  CachedOpPtr op = std::make_shared<CachedOp>(subgraph_sym, kwargs,
+                                              arg_names, params);
+  // TODO(zhengda) we need to avoid shape inference and memory plan whenever the op is
+  // called. Currently, CachedOp allocates memory each time Forward is called.
+  // I need to fix this once the PR for static memory allocation in CachedOp is
+  // merged. https://github.com/apache/incubator-mxnet/pull/10817
+  op->Forward(nullptr, inputs, outputs);
+
+  if (is_recording) {
+    all_outputs.push_back(coutputs);
+    iter_ops.push_back(op);
+  }
+
+  Imperative::Get()->set_is_recording(orig_is_record);
+}
+
+void LoopState::Backward(int iter_no,
+                         std::vector<NDArray> ograds,
+                         const std::vector<OpReqType> &req,
+                         std::vector<NDArray> igrads) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  CHECK_GT(iter_ops.size(), iter_no)
+      << "We didn't record the computation for iteration " << iter_no;
+  auto op = iter_ops[iter_no];
+  std::vector<NDArray*> inputs;
+  std::vector<NDArray*> outputs;
+  inputs.reserve(op->num_backward_inputs());
+  outputs.reserve(op->num_inputs());
+  for (size_t i = 0; i < ograds.size(); i++)
+    inputs.push_back(&ograds[i]);
+
+  const std::vector<bool> &save_inputs = op->save_inputs();
+  const std::vector<bool> &save_outputs = op->save_outputs();
+  CHECK_EQ(save_inputs.size(), all_inputs[iter_no].size());
+  CHECK_EQ(op->num_outputs(), all_outputs[iter_no].size());
+  for (size_t i = 0; i < all_inputs[iter_no].size(); i++) {
+    if (save_inputs[i])
+      inputs.push_back(&all_inputs[iter_no][i]);
+  }
+  for (size_t i = 0; i < all_outputs[iter_no].size(); i++) {
+    if (save_outputs[i])
+      inputs.push_back(&all_outputs[iter_no][i]);
+  }
+  CHECK_EQ(inputs.size(), op->num_backward_inputs());
+  for (size_t i = 0; i < igrads.size(); i++)
+    outputs.push_back(&igrads[i]);
+  CHECK_EQ(outputs.size(), op->num_inputs());
+
+  CHECK(!Imperative::AGInfo::IsNone(all_outputs[iter_no][0]));
+  const nnvm::NodeEntry &node_entry = all_outputs[iter_no][0].entry();
+  OpStatePtr state = Imperative::AGInfo::Get(node_entry.node).state;
+  op->Backward(false, state, inputs, req, outputs);
+}
+
+}  // namespace op
+}  // namespace mxnet
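LoopState records one CachedOp per iteration during the forward pass and replays them in Backward, assembling each replay's inputs from the output gradients plus the inputs and outputs saved for that iteration. A rough Python analogue of the replay loop, for intuition only; IterRecord objects and their backward method are hypothetical, not MXNet internals:

    def loop_backward(iter_records, out_grads, state_grads):
        """Replay recorded iterations in reverse, threading state gradients."""
        data_grads = []
        for i in reversed(range(len(iter_records))):
            # each record stands in for one CachedOp recorded in Forward()
            d_grad, state_grads = iter_records[i].backward(out_grads[i], state_grads)
            data_grads.append(d_grad)
        data_grads.reverse()  # restore iteration order
        return data_grads, state_grads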
diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h
new file mode 100644
index 000000000000..74e7cb2d1ccd
--- /dev/null
+++ b/src/operator/subgraph_op_common.h
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
+
+#include <mxnet/base.h>
+#include <mxnet/ndarray.h>
+#include <mxnet/op_attr_types.h>
+#include <vector>
+#include "../imperative/imperative_utils.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * Infer the data types of inputs and outputs of an operator that contains a
+ * subgraph.
+ */
+bool InferSubgraphDataType(const nnvm::Symbol &subgraph, std::vector<int> *in_type,
+                           std::vector<int> *out_type);
+
+/*
+ * Infer the storage types of inputs and outputs of an operator that contains a
+ * subgraph.
+ */
+bool InferSubgraphStorage(const nnvm::Symbol &subgraph,
+                          const int dev_mask,
+                          DispatchMode* dispatch_mode,
+                          std::vector<int> *in_attrs,
+                          std::vector<int> *out_attrs);
+
+/*
+ * Infer the storage types of inputs and outputs of the backward computation of
+ * an operator that contains a subgraph.
+ */
+bool InferSubgraphBackwardStorage(const nnvm::Symbol &subgraph,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs);
+
+/*
+ * This contains the states for running a loop and provides methods
+ * of running the subgraph computation for an iteration.
+ */
+class LoopState {
+  // These are output arrays from all iterations.
+  // They also contain the Op state for each CachedOp.
+  std::vector<std::vector<NDArray> > all_outputs;
+  std::vector<std::vector<NDArray> > all_inputs;
+  std::vector<std::vector<NDArray> > all_gradients;
+  std::vector<CachedOpPtr> iter_ops;
+  Symbol subgraph_sym;
+  nnvm::Graph subgraph;
+
+ public:
+  explicit LoopState(const Symbol &g) {
+    this->subgraph_sym = g;
+    this->subgraph.outputs = g.outputs;
+  }
+
+  void Forward(std::vector<NDArray> cinputs,
+               const std::vector<OpReqType>& req,
+               std::vector<NDArray> coutputs,
+               bool is_recording);
+  void Backward(int iter_no,
+                std::vector<NDArray> ograds,
+                const std::vector<OpReqType> &req,
+                std::vector<NDArray> igrads);
+  void Cleanup() {
+    all_outputs.clear();
+    all_inputs.clear();
+    all_gradients.clear();
+    iter_ops.clear();
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
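Shape inference follows the same pattern as the type and storage helpers declared above: per-iteration shapes are inferred inside the body, and the stacked output gains the iteration count on dimension 0. A short sketch from the Python side, with illustrative shapes:

    import mxnet as mx

    data = mx.sym.var('data')
    state = mx.sym.var('state')
    outs, states = mx.sym.contrib.foreach(
        lambda x, s: (x + s[0], [x + s[0]]), data, [state])
    out = mx.sym.Group([outs] + states)

    # 5 iterations over slices of shape (10, 4): the stacked output is
    # expected to come back as (5, 10, 4) and the final state as (10, 4).
    arg_shapes, out_shapes, _ = out.infer_shape(data=(5, 10, 4), state=(10, 4))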
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 24d5a932d7b2..d4ac88900c5d 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -18,9 +18,10 @@
 import mxnet as mx
 from mxnet import gluon
 import numpy as np
+import copy
 from numpy.testing import assert_allclose
 import unittest
-from mxnet.test_utils import almost_equal
+from mxnet.test_utils import almost_equal, assert_almost_equal
 
 
 def test_rnn():
@@ -28,13 +29,62 @@ def test_rnn():
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
     outputs, _ = cell.unroll(3, inputs)
     outputs = mx.sym.Group(outputs)
-    assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+    assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight',
+                                                    'rnn_i2h_bias', 'rnn_i2h_weight']
     assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
 
     args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
     assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
+class TestRNNLayer(gluon.HybridBlock):
+    def __init__(self, hidden_size, prefix=None, params=None):
+        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.cell = gluon.rnn.RNNCell(hidden_size, prefix='rnn_')
+
+    def hybrid_forward(self, F, inputs, states):
+        states = [states]
+        out, states = F.contrib.foreach(self.cell, inputs, states)
+        return out
+
+def test_contrib_rnn():
+    batch_size = 10
+    hidden_size = 100
+    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(5, batch_size, 50))
+    states = mx.nd.normal(loc=0, scale=1, shape=(batch_size, hidden_size))
+    layer = TestRNNLayer(hidden_size)
+    layer.initialize(ctx=mx.cpu(0))
+    res1 = layer(rnn_data, states)
+    params1 = layer.collect_params()
+    orig_params1 = copy.deepcopy(params1)
+
+    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res1 = layer(rnn_data, states)
+    res1.backward()
+    trainer.step(batch_size)
+
+    layer = TestRNNLayer(hidden_size)
+    layer.initialize(ctx=mx.cpu(0))
+    layer.hybridize()
+    res2 = layer(rnn_data, states)
+    params2 = layer.collect_params()
+    for key, val in orig_params1.items():
+        params2[key].set_data(val.data())
+
+    trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res2 = layer(rnn_data, states)
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+    res2.backward()
+    trainer.step(batch_size)
+
+    for key, val in params1.items():
+        weight1 = val.data()
+        weight2 = params2[key].data()
+        assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
 def test_lstm():
     cell = gluon.rnn.LSTMCell(100, prefix='rnn_')
     inputs = 
[mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index e7976e01f9d8..2b2a66725bb1 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -24,7 +24,7 @@ import itertools from numpy.testing import assert_allclose, assert_array_equal from mxnet.test_utils import * -from mxnet.base import py_str, MXNetError +from mxnet.base import py_str, MXNetError, _as_list from common import setup_module, with_seed import unittest @@ -5663,6 +5663,313 @@ def test_float16_min_max(): assert np.finfo('float16').max == mx.nd.max(a).asscalar() +@with_seed() +def test_foreach(): + v3 = mx.sym.var("v0") + v4 = mx.sym.var("v1") + v5 = mx.sym.var("v2") + v6 = mx.sym.var("v3") + v7 = mx.sym.var("v4") + + # This tests foreach with accumulation sum. + def step1(in1, states, free): + out = in1 * 2 + states[0] + free[0] + return (out, [out]) + def step2(in1, states, free): + out = states[0] + in1 * 2 + free[0] + return (out, [out]) + def step3(in1, states, free): + out = in1[0] + in1[1] + states[0] + states[1] + free[0] + return ([out, out * 2], [out * 2, out * 3]) + + def verify_foreach(step, in_syms, state_syms, free_syms, + in_arrs, init_states, frees, out_grads, is_train=True, + free_vars_func=None): + step_sym = lambda in_syms, state_syms : step(in_syms, state_syms, free_syms) + res, states = mx.sym.contrib.foreach(step_sym, in_syms, state_syms) + out = _as_list(res) + for i in range(len(out)): + out[i] = out[i] * 2 + out.extend(states) + out = mx.sym.Group(out) + arr_grads = [] + arg_dict = {} + arg_grad_dict = {} + i = 0 + for arr in _as_list(in_arrs): + arr_grad = mx.nd.empty(arr.shape) + arr_grads.append(arr_grad) + arg_dict['v'+str(i)] = arr + arg_grad_dict['v'+str(i)] = arr_grad + i = i + 1 + for arr in init_states: + arr_grad = mx.nd.empty(arr.shape) + arr_grads.append(arr_grad) + arg_dict['v'+str(i)] = arr + arg_grad_dict['v'+str(i)] = arr_grad + i = i + 1 + for arr in frees: + arr_grad = mx.nd.empty(arr.shape) + arr_grads.append(arr_grad) + arg_dict['v'+str(i)] = arr + arg_grad_dict['v'+str(i)] = arr_grad + i = i + 1 + + gin_order = [] + for name in out.list_inputs(): + name = name[1:] + gin_order.append(int(name)) + + e = out.bind(ctx=default_context(), args=arg_dict, args_grad=arg_grad_dict) + e.forward(is_train=is_train) + if (is_train): + # backward + tmp_grads = out_grads[0][:] + tmp_grads.extend(out_grads[1]) + e.backward(tmp_grads) + + # Below we use imperative to reimplement foreach and compute its gradients. 
+ res = [] + for i in range(len(_as_list(out_grads[0]))): + res.append([]) + for arr in _as_list(in_arrs): + arr.attach_grad() + for arr in init_states: + arr.attach_grad() + for arr in frees: + arr.attach_grad() + with mx.autograd.record(): + frees_imp = frees if free_vars_func is None else free_vars_func(frees) + step_imp = lambda in_arrs, state_arrs : step(in_arrs, state_arrs, frees_imp) + states = [mx.nd.expand_dims(s, 0) for s in init_states] + res, states = mx.nd.contrib.foreach(step_imp, in_arrs, init_states) + + res2 = _as_list(res) + for i in range(len(res2)): + res2[i] = res2[i] * 2 + if isinstance(states, list): + states = [mx.nd.expand_dims(s, 0) for s in states] + res2.extend(states) + else: + states = mx.nd.expand_dims(states, 0) + res2.append(states) + res = mx.nd.concat(*res2, dim=0) + + tmp_grads = out_grads[0][:] + tmp_grads1 = [mx.nd.expand_dims(grad, 0) for grad in out_grads[1]] + tmp_grads.extend(tmp_grads1) + if (is_train): + res.backward(mx.nd.concat(*tmp_grads, dim=0)) + for i in range(len(res2)): + assert_almost_equal(e.outputs[i].asnumpy(), res2[i].asnumpy(), + rtol=0.001, atol=0.0001) + if (is_train): + all_ins = _as_list(in_arrs)[:] + all_ins.extend(init_states) + all_ins.extend(frees) + for i in range(len(all_ins)): + assert_almost_equal(all_ins[i].grad.asnumpy(), + e.grad_arrays[gin_order[i]].asnumpy()) + + # Test cases: + # * graph inputs are stored in different orders. + # This is to test if foreach finds the data arrays and weight arrays + # in the right location. + # * the number of iterations: odd or even. + # * multiple inputs and multiple outputs. + # * inference. + + #states = [mx.nd.random.uniform(shape=(2))] + + #frees1 = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] + #arrs = mx.nd.random.uniform(shape=(3, 2)) + states = [mx.nd.arange(2)] + + frees1 = [mx.nd.arange(2), mx.nd.arange(2) + 1] + arrs = mx.nd.arange(6).reshape(shape=(3, 2)) + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True, + lambda frees : [frees[0] + frees[1]]) + verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False, + lambda frees : [frees[0] + frees[1]]) + + frees = [mx.nd.random.uniform(shape=(2))] + arrs = mx.nd.random.uniform(shape=(2, 2)) + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False) + + arrs = mx.nd.random.uniform(shape=(3, 2)) + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False) + + arrs = mx.nd.random.uniform(shape=(2, 2)) + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False) + + arrs = mx.nd.random.uniform(shape=(3, 2)) + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, 
out_grads, False) + + # Test multiple inputs and outputs. + arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))] + states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] + out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] + verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads) + verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False) + + +@with_seed() +def test_foreach_nested(): + # Test nested foreach. + def step_in(in1, states): + out = in1 * 2 + states[0] + return (out, [out]) + + def step(in1, states): + out1 = mx.sym.contrib.foreach(step_in, in1, states) + out = mx.sym.broadcast_add(out1[0], states[0]) + return (out, [mx.sym.squeeze(mx.sym.slice(out, begin=(0, 0), end=(1, 2)))]) + + data_sym = mx.sym.var("v1") + state_sym = mx.sym.var("v2") + out = mx.sym.contrib.foreach(step, data_sym, [state_sym]) + + out1 = _as_list(out[0]) + for i in range(len(out1)): + out1[i] = out1[i] + out1.extend(out[1]) + out = mx.sym.Group(out1) + + data = mx.nd.arange(4).reshape((1, 2, 2)) + state = mx.nd.arange(2) + data_grad = mx.nd.empty(data.shape) + state_grad = mx.nd.empty(state.shape) + e = out.bind(ctx=default_context(), args={'v1':data, 'v2':state}, + args_grad={'v1':data_grad, 'v2':state_grad}) + e.forward(is_train=True) + out = mx.nd.zeros_like(data) + for i in range(data.shape[0]): + data1 = data[i] + out1 = mx.nd.zeros_like(data1) + for j in range(data1.shape[0]): + if (j > 0): + out1[j] = out1[j-1] + data1[j] * 2 + else: + out1[j] = data1[j] * 2 + state + if (i > 0): + state = mx.nd.squeeze(mx.nd.slice(out[i-1], begin=(0, 0), end=(1, 2))) + out[i] = mx.nd.broadcast_add(out1, state) + else: + out[i] = mx.nd.broadcast_add(out1, state) + out = out + assert_almost_equal(out.asnumpy(), e.outputs[0].asnumpy(), rtol=0.001, atol=0.0001) + + +@with_seed() +def test_foreach_lstm(): + data = mx.sym.var("data") + init_h = mx.sym.var("h") + init_c = mx.sym.var("c") + i2h_weight = mx.sym.var("i2h_weight") + h2h_weight = mx.sym.var("h2h_weight") + i2h_bias = mx.sym.var("i2h_bias") + h2h_bias = mx.sym.var("h2h_bias") + + # This tests foreach with accumulation sum. + def step(in1, states): + params = mx.rnn.RNNParams() + params._params['i2h_weight'] = i2h_weight + params._params['h2h_weight'] = h2h_weight + params._params['i2h_bias'] = i2h_bias + params._params['h2h_bias'] = h2h_bias + lstm = mx.rnn.LSTMCell(4, prefix='mylstm_', params=params) + next_h, [next_h, next_c] = lstm(in1, states) + # TODO This is problematic. We can't count on the user to define two different symbols. 
+ return (next_h, [next_h, next_c]) + + def sym_group(out): + if (isinstance(out[0], mx.sym.Symbol)): + ret = [out[0]] + else: + ret = out[0] + ret.extend(out[1]) + return mx.sym.Group(ret) + + data_arr = mx.nd.random.uniform(shape=(2, 2, 4)) + h_arr = mx.nd.random.uniform(shape=(2, 4)) + c_arr = mx.nd.random.uniform(shape=(2, 4)) + i2h_warr = mx.nd.random.uniform(shape=(16, 4)) + h2h_warr = mx.nd.random.uniform(shape=(16, 4)) + i2h_barr = mx.nd.random.uniform(shape=(16)) + h2h_barr = mx.nd.random.uniform(shape=(16)) + + data_arr_grad1 = mx.nd.empty(data_arr.shape) + h_arr_grad1 = mx.nd.empty(h_arr.shape) + c_arr_grad1 = mx.nd.empty(c_arr.shape) + i2h_warr_grad1 = mx.nd.empty(i2h_warr.shape) + h2h_warr_grad1 = mx.nd.empty(h2h_warr.shape) + i2h_barr_grad1 = mx.nd.empty(i2h_barr.shape) + h2h_barr_grad1 = mx.nd.empty(h2h_barr.shape) + out = mx.sym.contrib.foreach(step, data, [init_h, init_c]) + out = sym_group(out) + e1 = out.bind(ctx=default_context(), + args={'data': data_arr, 'h': h_arr, 'c': c_arr, + 'i2h_weight': i2h_warr, 'h2h_weight': h2h_warr, + 'i2h_bias': i2h_barr, 'h2h_bias': h2h_barr}, + args_grad={'data': data_arr_grad1, 'h': h_arr_grad1, 'c': c_arr_grad1, + 'i2h_weight': i2h_warr_grad1, 'h2h_weight': h2h_warr_grad1, + 'i2h_bias': i2h_barr_grad1, 'h2h_bias': h2h_barr_grad1}) + e1.forward(is_train=True) + outputs1 = e1.outputs + # backward + out_grads = [] + for arr in e1.outputs: + out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape)) + e1.backward(out_grads) + + data_arr_grad2 = mx.nd.empty(data_arr.shape) + h_arr_grad2 = mx.nd.empty(h_arr.shape) + c_arr_grad2 = mx.nd.empty(c_arr.shape) + i2h_warr_grad2 = mx.nd.empty(i2h_warr.shape) + h2h_warr_grad2 = mx.nd.empty(h2h_warr.shape) + i2h_barr_grad2 = mx.nd.empty(i2h_barr.shape) + h2h_barr_grad2 = mx.nd.empty(h2h_barr.shape) + lstm = mx.rnn.LSTMCell(4, prefix='mylstm_') + h = init_h + c = init_c + unroll_outs = [] + for inputs in mx.sym.split(data, num_outputs=data_arr.shape[0], axis=0, squeeze_axis=True): + h, [h, c] = lstm(inputs, [h, c]) + unroll_outs.append(mx.sym.expand_dims(h, axis=0)) + unroll_outs = mx.sym.concat(*unroll_outs, dim=0) + out = mx.sym.Group([unroll_outs, h, c]) + e2 = out.bind(ctx=default_context(), + args={'data': data_arr, 'h': h_arr, 'c': c_arr, + 'mylstm_i2h_weight': i2h_warr, 'mylstm_h2h_weight': h2h_warr, + 'mylstm_i2h_bias': i2h_barr, 'mylstm_h2h_bias': h2h_barr}, + args_grad={'data': data_arr_grad2, 'h': h_arr_grad2, 'c': c_arr_grad2, + 'mylstm_i2h_weight': i2h_warr_grad2, 'mylstm_h2h_weight': h2h_warr_grad2, + 'mylstm_i2h_bias': i2h_barr_grad2, 'mylstm_h2h_bias': h2h_barr_grad2}) + e2.forward(is_train=True) + outputs2 = e2.outputs + e2.backward(out_grads) + + for i in range(len(outputs2)): + assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(), rtol=0.001, atol=0.0001) + for i in range(len(e1.grad_arrays)): + assert_almost_equal(e1.grad_arrays[i].asnumpy(), e2.grad_arrays[i].asnumpy()) + + @with_seed() def test_squeeze_op(): def check_squeeze_op(shape, axis=None):