diff --git a/CMakeLists.txt b/CMakeLists.txt index 976c736f5f35..7667c87e5b4e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,7 +143,7 @@ else(MSVC) add_definitions(-DMSHADOW_USE_F16C=0) endif() set(CMAKE_POSITION_INDEPENDENT_CODE ON) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wno-unknown-pragmas -Wno-sign-compare") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wno-unknown-pragmas -Wno-sign-compare -Werror=return-type") if ("${CMAKE_CXX_COMPILER_ID}" MATCHES ".*Clang$") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-braced-scalar-init") endif() diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index d4f756f5333c..c699d7825cee 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1224,7 +1224,7 @@ MXNET_DLL int MXAutogradBackward(mx_uint num_output, * \param output_handles output NDArrays * \param ograd_handles head gradient for NDArrays * \param num_variables number of variables - * \param + * \param var_handles variables to compute gradient with respect to (d / d var) * \param retain_graph whether to keep the graph after backward * \param is_train whether to do backward for training or inference * \return 0 when success, -1 when failure happens diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index a86cc085a34b..2976571a63cf 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -35,48 +35,53 @@ #include "./ndarray.h" namespace mxnet { -/*! \brief runtime functions for NDArray */ -class Imperative { +/*! + * Autograd Info used in class: nnvm::Node::info + */ +class AGInfo { public: - /*! \brief */ - class AGInfo { - public: - Context ctx; - OpReqType grad_req; - OpStatePtr state; - std::vector outputs; - std::vector out_grads; - bool fresh_out_grad; + Context ctx; + OpReqType grad_req; + OpStatePtr state; + std::vector outputs; + std::vector out_grads; + bool fresh_out_grad; - AGInfo() : + AGInfo() : grad_req(kNullOp), fresh_out_grad(false) {} - static void Clear(const nnvm::NodePtr& node) { - if (node == nullptr || node->info.empty()) return; - AGInfo& info = Get(node); - if (info.grad_req != kNullOp) return; - node->info.clear(); - } + static void Clear(const nnvm::NodePtr& node) { + if (node == nullptr || node->info.empty()) return; + AGInfo& info = Get(node); + if (info.grad_req != kNullOp) return; + node->info.clear(); + } - static AGInfo& Get(const nnvm::NodePtr& node) { - return dmlc::get(node->info); - } + static AGInfo& Get(const nnvm::NodePtr& node) { + return dmlc::get(node->info); + } - static AGInfo& Create(const nnvm::NodePtr& node) { - node->info.construct(); - return Get(node); - } + static AGInfo& Create(const nnvm::NodePtr& node) { + node->info.construct(); + return Get(node); + } - static bool IsNone(const NDArray& arr) { - return arr.entry_.node == nullptr || arr.entry_.node->info.empty(); - } + static bool IsNone(const NDArray& arr) { + return arr.autograd_.node == nullptr || arr.autograd_.node->info.empty(); + } + + static bool IsVariable(const nnvm::NodePtr& node) { + AGInfo& info = Get(node); + return info.grad_req != kNullOp && info.outputs.size() == 1 + && info.out_grads.size() == 1; + } +}; + +/*! \brief runtime functions for NDArray */ +class Imperative { + public: + /*! \brief */ - static bool IsVariable(const nnvm::NodePtr& node) { - AGInfo& info = Get(node); - return info.grad_req != kNullOp && info.outputs.size() == 1 - && info.out_grads.size() == 1; - } - }; /*! \brief whether operator recording is on. */ bool is_training() const { return is_train_; @@ -97,11 +102,11 @@ class Imperative { is_recording_ = is_recording; return old; } - /*! brief whether numpy compatibility is on. */ + /*! \brief whether numpy compatibility is on. */ bool is_np_shape() const { return is_np_shape_; } - /*! brief turn on or turn off numpy compatibility switch. */ + /*! \brief turn on or turn off numpy compatibility switch. */ bool set_is_np_shape(bool is_np_shape) { bool old = is_np_shape_; is_np_shape_ = is_np_shape; @@ -160,7 +165,29 @@ class Imperative { private: friend class NDArray; - /*! \brief make constructor protected. */ + /*! Create a forward graph + * @param output_nodes graph node vector to add nodes to + * @param outputs source ndarrays + * @return vector of nodes + */ + static nnvm::Graph CreateGraph(const std::vector &outputs); + /*! Create gradient nodes using output shapes and ctx. + * Gradient heads are initialized to 1 if they are not present (nullptr) + * @return vector of nodes + */ + static std::vector CreateHeadGradientNodes(const std::vector& outputs, + const std::vector& ograds); + + struct GradientVariableNodes; + /*! Create variable nodes. + * If variables is provided, gradient nodes are crated for them. Otherwise it uses read only + * inputs reachable from the outputs. + * @param variables + * @param outputs + * @return aux data structure with nodes and arrays for gradients + */ + GradientVariableNodes CreateGradientVariableNodes(const std::vector& variables, + const std::vector& outputs); Imperative() { if (PreferBulkExecTrain()) backward_bulk_size_ = BulkExecMaxNodeTrainBwd(); @@ -168,9 +195,9 @@ class Imperative { /*! \brief find the input/output ndarrays that are needed for backward */ void GetBackwardDependency( const nnvm::NodePtr& node, - uint32_t num_inputs, uint32_t num_outputs, - std::vector *p_save_inputs, - std::vector *p_save_outputs); + size_t num_inputs, size_t num_outputs, + std::vector *save_inputs, + std::vector *save_outputs); /*! \brief indicate whether is training. */ #if DMLC_CXX11_THREAD_LOCAL static thread_local bool is_train_; diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 176aa0aaa197..56273013ac66 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -80,10 +80,12 @@ class MKLDNNMemory; * \brief ndarray interface */ class NDArray { + friend class AGInfo; + friend class Imperative; public: /*! \brief default constructor */ NDArray() - : entry_(nullptr) { + : autograd_(nullptr) { } /*! * \brief constructs a new dynamic NDArray @@ -98,7 +100,7 @@ class NDArray { shape_(shape), dtype_(dtype), storage_type_(kDefaultStorage), - entry_(nullptr) { + autograd_(nullptr) { } /*! \brief constructor for NDArray with storage type */ @@ -117,7 +119,7 @@ class NDArray { shape_(), dtype_(dtype), storage_type_(kDefaultStorage), - entry_(nullptr) { + autograd_(nullptr) { } /*! * \brief constructing a static NDArray that shares data with TBlob @@ -131,7 +133,7 @@ class NDArray { shape_(data.shape_), dtype_(data.type_flag_), storage_type_(kDefaultStorage), - entry_(nullptr) { + autograd_(nullptr) { } /*! @@ -149,7 +151,7 @@ class NDArray { }), shape_(data.shape_), dtype_(data.type_flag_), storage_type_(kDefaultStorage), - entry_(nullptr) { + autograd_(nullptr) { } /*! \brief create ndarray from shared memory */ @@ -158,7 +160,7 @@ class NDArray { shape_(shape), dtype_(dtype), storage_type_(kDefaultStorage), - entry_(nullptr) { + autograd_(nullptr) { } /*! @@ -177,7 +179,7 @@ class NDArray { shape_(shape), dtype_(data.type_flag_), storage_type_(stype), - entry_(nullptr) { + autograd_(nullptr) { } /*! * \brief initialize the NDArray, assuming it is not assigned a meaningful shape before @@ -387,6 +389,7 @@ class NDArray { } /*! \return the associated variable of the ndarray.*/ inline Engine::VarHandle var() const { + CHECK(ptr_); return ptr_->var; } /*! \return byte offset in chunk of the ndarray*/ @@ -395,6 +398,7 @@ class NDArray { } /*! \brief return var version of the NDArray*/ inline size_t version() const { + CHECK(var()); return var()->version(); } /*! @@ -649,7 +653,7 @@ class NDArray { */ NDArray Detach() const { NDArray ret(*this); - ret.entry_ = nnvm::NodeEntry(nullptr); + ret.autograd_ = nnvm::NodeEntry(nullptr); return ret; } @@ -812,7 +816,6 @@ class NDArray { std::vector* keys); private: - friend class Imperative; /*! \brief the real data chunk that backs NDArray */ // shandle is used to store the actual values in the NDArray // aux_handles store the aux data(such as indices) if it's needed by non-default storage. @@ -1102,7 +1105,7 @@ class NDArray { /*! \brief storage type of data */ NDArrayStorageType storage_type_ = kUndefinedStorage; /*! \brief node entry for autograd */ - nnvm::NodeEntry entry_; + nnvm::NodeEntry autograd_; /*! * \brief internal TBlob * \note When user access tblob_ by some const methods like diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 889b5028a460..ad061c993334 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -54,7 +54,8 @@ enum OpReqType { */ kWriteInplace, /*! \brief add to the provided space */ - kAddTo + kAddTo, + kOpReqTypeMax }; /*! diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 4546659ca64e..47529cb6f078 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -355,7 +355,7 @@ int MXAutogradBackwardEx(mx_uint num_output, } auto grads = Imperative::Get()->Backward(outputs, ograds, variables, is_train, - retain_graph, create_graph); + retain_graph, create_graph); if (num_variables != 0) { ret->ret_handles.clear(); ret->out_types.clear(); diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index 42d03e55a275..7d69659767a0 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -102,6 +102,7 @@ RunContext StreamManager::GetRunContext( #endif // MXNET_USE_CUDA default: LOG(FATAL) << "Not Reached"; + break; } } return ret; diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index f544d6ba3392..745bbf948576 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -258,6 +258,7 @@ inline Graph MXGradient( if (copy_op_str != std::string()) { graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); } + /// @sa nnvm::pass::Gradient in gradient.cc return ApplyPass(std::move(graph), "MXGradient"); } } // namespace pass diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index e6a177e27847..fd80d72f99e5 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -124,66 +124,64 @@ void Imperative::MarkVariables( const std::vector& variables, const std::vector& grad_reqs, const std::vector& gradients) { - for (uint32_t i = 0; i < variables.size(); ++i) { + for (size_t i = 0; i < variables.size(); ++i) { std::string str_c(std::to_string(variable_count_++)); - - variables[i]->entry_ = nnvm::NodeEntry{ + // Add autograd storage for variables and link to the graph + variables[i]->autograd_ = nnvm::NodeEntry{ nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0}; - AGInfo& info = AGInfo::Create(variables[i]->entry_.node); + AGInfo &info = AGInfo::Create(variables[i]->autograd_.node); info.outputs.emplace_back(variables[i]->Detach()); info.out_grads.emplace_back(gradients[i]->Detach()); info.grad_req = static_cast(grad_reqs[i]); + CHECK(info.grad_req < kOpReqTypeMax) << "gradient update request out of range"; info.ctx = variables[i]->ctx(); - - gradients[i]->entry_ = nnvm::NodeEntry{ + // Handle gradients themselves + gradients[i]->autograd_ = nnvm::NodeEntry{ nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0}; - AGInfo& grad_info = AGInfo::Create(gradients[i]->entry_.node); + AGInfo &grad_info = AGInfo::Create(gradients[i]->autograd_.node); grad_info.outputs.emplace_back(gradients[i]->Detach()); grad_info.ctx = gradients[i]->ctx(); } } - void Imperative::GetBackwardDependency( const nnvm::NodePtr& node, - uint32_t num_inputs, uint32_t num_outputs, - std::vector *p_save_inputs, - std::vector *p_save_outputs) { + size_t num_inputs, size_t num_outputs, + std::vector *save_inputs, + std::vector *save_outputs) { static auto& fgradient = nnvm::Op::GetAttr("FGradient"); - std::vector& save_inputs = *p_save_inputs; - std::vector& save_outputs = *p_save_outputs; - save_inputs.resize(num_inputs); - save_outputs.resize(num_outputs); - std::fill(save_inputs.begin(), save_inputs.end(), false); - std::fill(save_outputs.begin(), save_outputs.end(), false); + save_inputs->resize(num_inputs); + save_outputs->resize(num_outputs); + std::fill(save_inputs->begin(), save_inputs->end(), false); + std::fill(save_outputs->begin(), save_outputs->end(), false); node->inputs.clear(); node->inputs.reserve(num_inputs); - for (uint32_t i = 0; i < num_inputs; ++i) { - node->inputs.emplace_back(nnvm::NodeEntry{nullptr, i, 0}); + for (size_t i = 0; i < num_inputs; ++i) { + node->inputs.emplace_back(nullptr, i, 0); } if (fgradient.count(node->op())) { std::vector ograd_entries; ograd_entries.reserve(num_outputs); - for (uint32_t i = 0; i < num_outputs; ++i) { + for (size_t i = 0; i < num_outputs; ++i) { ograd_entries.emplace_back(nullptr, i, 1); } auto igrad_entries = fgradient[node->op()](node, ograd_entries); for (const auto& i : igrad_entries) { if (i.node == nullptr && i.version == 0) { - save_inputs[i.index] = true; + (*save_inputs)[i.index] = true; } else if (i.node == node) { - save_outputs[i.index] = true; + (*save_outputs)[i.index] = true; } } DFSVisit(igrad_entries, [&](const nnvm::NodePtr& gnode) { if (!gnode || gnode == node) return; for (const auto& i : gnode->inputs) { if (i.node == nullptr && i.version == 0) { - save_inputs[i.index] = true; + (*save_inputs)[i.index] = true; } else if (i.node == node) { - save_outputs[i.index] = true; + (*save_outputs)[i.index] = true; } } }); @@ -218,7 +216,7 @@ void Imperative::RecordOp( nnvm::NodePtr node = nnvm::Node::Create(); node->attrs = std::move(attrs); - node->attrs.name = "node_" + std::to_string(node_count_++); + node->attrs.name = "node_" + std::to_string(node_count_++) + "_" + node->attrs.op->name; AGInfo& info = AGInfo::Create(node); info.state = state; info.ctx = outputs[0]->ctx(); @@ -232,16 +230,13 @@ void Imperative::RecordOp( node->inputs.resize(inputs.size()); } - std::vector& save_inputs = *p_save_inputs; - std::vector& save_outputs = *p_save_outputs; - for (size_t i = 0; i < inputs.size(); ++i) { if (AGInfo::IsNone(*(inputs[i]))) { nnvm::NodeEntry entry{nnvm::Symbol::CreateVariable( "null" + std::to_string(variable_count_++)).outputs[0].node, 0, 0}; AGInfo& input_info = AGInfo::Create(entry.node); input_info.ctx = inputs[i]->ctx(); - if (save_inputs[i]) { + if ((*p_save_inputs)[i]) { input_info.outputs.emplace_back(*inputs[i]); } else { // Put a dummy array here since it will not be used. @@ -250,11 +245,12 @@ void Imperative::RecordOp( input_info.outputs.back().dtype_ = inputs[i]->dtype(); input_info.outputs.back().storage_type_ = inputs[i]->storage_type(); } - inputs[i]->entry_ = std::move(entry); // assign last to prevent cyclic reference - } else if (save_inputs[i]) { - AGInfo::Get(inputs[i]->entry_.node).outputs[inputs[i]->entry_.index] = inputs[i]->Detach(); + inputs[i]->autograd_ = std::move(entry); // assign last to prevent cyclic reference + } else if ((*p_save_inputs)[i]) { + AGInfo::Get(inputs[i]->autograd_.node).outputs[inputs[i]->autograd_.index] = + inputs[i]->Detach(); } - node->inputs[i] = inputs[i]->entry_; + node->inputs[i] = inputs[i]->autograd_; } for (auto output : outputs) { @@ -263,8 +259,8 @@ void Imperative::RecordOp( << "recording with autograd."; } - for (uint32_t i = 0; i < outputs.size(); ++i) { - if (save_outputs[i]) { + for (size_t i = 0; i < outputs.size(); ++i) { + if ((*p_save_outputs)[i]) { info.outputs.emplace_back(outputs[i]->Detach()); } else { // Put a dummy array here since it will not be used. @@ -273,40 +269,35 @@ void Imperative::RecordOp( info.outputs.back().dtype_ = outputs[i]->dtype(); info.outputs.back().storage_type_ = outputs[i]->storage_type(); } - outputs[i]->entry_ = nnvm::NodeEntry{node, i, 0}; + outputs[i]->autograd_ = nnvm::NodeEntry{node, static_cast(i), 0}; } } -std::vector Imperative::Backward( - const std::vector& outputs, - const std::vector& ograds, - const std::vector& variables, - bool is_train, bool retain_graph, - bool create_graph) { - using namespace nnvm; - using namespace imperative; - static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; - static const Op* copy_op = Op::Get("_copy"); - - // Construct forward graph - Graph graph; - graph.outputs.reserve(outputs.size()); - for (const auto& i : outputs) { +nnvm::Graph Imperative::CreateGraph(const std::vector &outputs) { + nnvm::Graph g; + std::vector output_nodes; + output_nodes.reserve(outputs.size()); + for (const auto &i : outputs) { CHECK(!AGInfo::IsNone(*i)) << "Cannot differentiate node because it is not in a computational graph. " << "You need to set is_recording to true or use autograd.record() to save " << "computational graphs for backward. If you want to differentiate the same " << "graph twice, you need to pass retain_graph=True to backward."; - graph.outputs.emplace_back(i->entry_); + g.outputs.emplace_back(i->autograd_); } - size_t num_forward_outputs = graph.outputs.size(); + return g; +} - // Prepare head gradients +std::vector Imperative::CreateHeadGradientNodes( + const std::vector &outputs, + const std::vector &ograds) { + using nnvm::NodeEntry; + using nnvm::Node; std::vector ograd_entries; ograd_entries.reserve(ograds.size()); for (size_t i = 0; i < outputs.size(); ++i) { - ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0}); - AGInfo& info = AGInfo::Create(ograd_entries.back().node); + ograd_entries.emplace_back(Node::Create()); + AGInfo &info = AGInfo::Create(ograd_entries.back().node); info.ctx = outputs[i]->ctx(); if (ograds[i] != nullptr) { info.outputs.emplace_back(*ograds[i]); @@ -318,170 +309,209 @@ std::vector Imperative::Backward( } } } + return ograd_entries; +} - // Get gradient graph - Symbol sym; - sym.outputs = graph.outputs; - std::vector xs; - std::vector x_grads; - std::vector x_reqs; - if (variables.size()) { - xs.reserve(variables.size()); - x_grads.reserve(variables.size()); - x_reqs.reserve(variables.size()); +struct Imperative::GradientVariableNodes { + std::vector variable_nodes; + std::vector gradients; + std::vector op_req_types; +}; + +Imperative::GradientVariableNodes Imperative::CreateGradientVariableNodes( + const std::vector &variables, + const std::vector &outputs) { + GradientVariableNodes var_nodes; + if (!variables.empty()) { + var_nodes.variable_nodes.reserve(variables.size()); + var_nodes.gradients.reserve(variables.size()); + var_nodes.op_req_types.reserve(variables.size()); for (size_t i = 0; i < variables.size(); ++i) { CHECK(!AGInfo::IsNone(*variables[i]) && - AGInfo::IsVariable(variables[i]->entry_.node)) + AGInfo::IsVariable(variables[i]->autograd_.node)) << "Cannot differentiate with respect to the " << i+1 << "-th variable" - << " because it does not require gradient."; - xs.emplace_back(variables[i]->entry_); - x_grads.push_back(new NDArray()); - x_reqs.push_back(kWriteTo); + << " because it does not require gradient. Did you forget attach_grad()?"; + var_nodes.variable_nodes.emplace_back(variables[i]->autograd_); + var_nodes.gradients.push_back(new NDArray()); + var_nodes.op_req_types.push_back(kWriteTo); } } else { - std::vector args = sym.ListInputs(Symbol::kReadOnlyArgs); - xs.reserve(args.size()); - x_grads.reserve(args.size()); - x_reqs.reserve(args.size()); - for (const auto& i : args) { - AGInfo& info = AGInfo::Get(i); - if (info.grad_req == kNullOp) continue; - xs.emplace_back(NodeEntry{i, 0, 0}); - x_grads.push_back(&info.out_grads[0]); - x_reqs.push_back(info.grad_req); - info.fresh_out_grad = true; + nnvm::Symbol s; + s.outputs = outputs; + std::vector input_ro_nodes = s.ListInputs(Symbol::kReadOnlyArgs); + var_nodes.variable_nodes.reserve(input_ro_nodes.size()); + var_nodes.gradients.reserve(input_ro_nodes.size()); + var_nodes.op_req_types.reserve(input_ro_nodes.size()); + for (const auto& node : input_ro_nodes) { + AGInfo& info = AGInfo::Get(node); + if (info.grad_req != kNullOp) { + var_nodes.variable_nodes.emplace_back(node); + var_nodes.gradients.push_back(&info.out_grads[0]); + var_nodes.op_req_types.push_back(info.grad_req); + info.fresh_out_grad = true; + } } - CHECK_GT(xs.size(), 0) + CHECK_GT(var_nodes.variable_nodes.size(), 0) << "There are no inputs in computation graph that require gradients."; } + return var_nodes; +} - Graph g_graph = pass::MXGradient( - graph, graph.outputs, xs, ograd_entries, +std::vector Imperative::Backward( + const std::vector& outputs, + const std::vector& ograds, + const std::vector& variables, + bool is_train, bool retain_graph, + bool create_graph) { + using namespace nnvm; + using namespace imperative; + static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; + static const Op* copy_op = Op::Get("_copy"); + + Graph graph = CreateGraph(outputs); + + // Prepare head gradient nodes + std::vector ograd_entries = CreateHeadGradientNodes(outputs, ograds); + + // Get variable nodes + GradientVariableNodes gvars = CreateGradientVariableNodes(variables, graph.outputs); + + // Run backward on the graph + Graph gradient_graph = pass::MXGradient( + graph, graph.outputs, gvars.variable_nodes, ograd_entries, exec::AggregateGradient, nullptr, nullptr, zero_ops, "_copy"); - CHECK_EQ(g_graph.outputs.size(), xs.size()); - for (const auto& e : g_graph.outputs) { - if (e.node->op() == nullptr) { + + CHECK_EQ(gradient_graph.outputs.size(), gvars.variable_nodes.size()); + std::vector forward_outputs = graph.outputs; + const size_t num_forward_outputs = graph.outputs.size(); + + // TODO(larroy): move inside pass::MXGradient + for (const auto& backward_node : gradient_graph.outputs) { + if (backward_node.node->is_variable()) { auto node = Node::Create(); node->attrs.op = copy_op; - node->inputs.push_back(e); + node->inputs.push_back(backward_node); graph.outputs.emplace_back(std::move(node)); } else { - graph.outputs.push_back(e); + graph.outputs.push_back(backward_node); } } - const auto& idx = graph.indexed_graph(); + + auto& indexed_graph = graph.indexed_graph(); // get number of nodes used in forward pass size_t num_forward_nodes = 0; size_t num_forward_entries = 0; for (size_t i = 0; i < num_forward_outputs; ++i) { num_forward_nodes = std::max( - num_forward_nodes, static_cast(idx.outputs()[i].node_id + 1)); + num_forward_nodes, static_cast(indexed_graph.outputs()[i].node_id + 1)); num_forward_entries = std::max( - num_forward_entries, static_cast(idx.entry_id(idx.outputs()[i])) + 1); + num_forward_entries, static_cast(indexed_graph.entry_id( + indexed_graph.outputs()[i])) + 1); } // Allocate buffer - std::vector buff(idx.num_node_entries()); + std::vector buff(indexed_graph.num_node_entries()); std::vector ref_count(buff.size(), 0); std::vector states; std::vector arrays; arrays.reserve(buff.size()); - for (auto& buffered_array : buff) { + for (auto& buffered_array : buff) arrays.push_back(&buffered_array); - } + if (create_graph) { states.resize(num_forward_nodes); - nnvm::DFSVisit(sym.outputs, [&](const nnvm::NodePtr& n) { - AGInfo& info = AGInfo::Get(n); - states[idx.node_id(n.get())] = info.state; - for (uint32_t i = 0; i < info.outputs.size(); ++i) { - CHECK(idx.exist(n.get())); - size_t nid = idx.node_id(n.get()); - size_t eid = idx.entry_id(nid, i); + nnvm::DFSVisit(forward_outputs, [&](const nnvm::NodePtr& n) { + const AGInfo& info = AGInfo::Get(n); + states.at(indexed_graph.node_id(n.get())) = info.state; + for (size_t i = 0; i < info.outputs.size(); ++i) { + CHECK(indexed_graph.exist(n.get())); + const size_t nid = indexed_graph.node_id(n.get()); + const size_t eid = indexed_graph.entry_id(nid, i); buff[eid] = info.outputs[i]; - buff[eid].entry_ = NodeEntry{n, i, 0}; + buff[eid].autograd_ = NodeEntry{n, static_cast(i), 0}; ref_count[eid] = 1; } }); for (auto& ograd_entry : ograd_entries) { - AGInfo& info = AGInfo::Get(ograd_entry.node); - if (!idx.exist(ograd_entry.node.get())) continue; - size_t eid = idx.entry_id(ograd_entry); + const AGInfo& info = AGInfo::Get(ograd_entry.node); + if (!indexed_graph.exist(ograd_entry.node.get())) continue; + size_t eid = indexed_graph.entry_id(ograd_entry); buff[eid] = info.outputs[0]; - buff[eid].entry_ = ograd_entry; + buff[eid].autograd_ = ograd_entry; } } else { states.reserve(num_forward_nodes); for (size_t i = 0; i < num_forward_nodes; ++i) { - const AGInfo& info = dmlc::get(idx[i].source->info); + // TODO(larroy): This is a code smell 💩 + AGInfo& info = const_cast(dmlc::get(indexed_graph[i].source->info)); states.emplace_back(info.state); for (size_t j = 0; j < info.outputs.size(); ++j) { - size_t eid = idx.entry_id(i, j); - arrays[eid] = const_cast(&(info.outputs[j])); - - if (retain_graph || info.grad_req != kNullOp) ref_count[eid] = 1; + const size_t eid = indexed_graph.entry_id(i, j); + arrays[eid] = &(info.outputs[j]); + if (retain_graph || info.grad_req != kNullOp) + ref_count[eid] = 1; } } for (auto& ograd_entry : ograd_entries) { - if (!idx.exist(ograd_entry.node.get())) continue; + if (!indexed_graph.exist(ograd_entry.node.get())) continue; AGInfo& info = AGInfo::Get(ograd_entry.node); - arrays[idx.entry_id(ograd_entry)] = &info.outputs[0]; + arrays[indexed_graph.entry_id(ograd_entry)] = &info.outputs[0]; } } for (size_t i = num_forward_outputs; i < graph.outputs.size(); ++i) { - size_t eid = idx.entry_id(graph.outputs[i]); - arrays[eid] = x_grads[i - num_forward_outputs]; + size_t eid = indexed_graph.entry_id(graph.outputs[i]); + arrays[eid] = gvars.gradients[i - num_forward_outputs]; ref_count[eid] = 1; } // Assign context - auto vctx = PlaceDevice(idx); + auto vctx = PlaceDevice(indexed_graph); // Infer shape type { std::pair node_range, entry_range; - node_range = {num_forward_nodes, idx.num_nodes()}; - entry_range = {num_forward_entries, idx.num_node_entries()}; + node_range = {num_forward_nodes, indexed_graph.num_nodes()}; + entry_range = {num_forward_entries, indexed_graph.num_node_entries()}; ShapeVector shapes; - shapes.reserve(idx.num_node_entries()); + shapes.reserve(indexed_graph.num_node_entries()); bool contain_unknown = false; for (const auto& i : arrays) shapes.emplace_back(i->shape()); CheckAndInferShape(&graph, std::move(shapes), false, node_range, entry_range, &contain_unknown); DTypeVector dtypes; - dtypes.reserve(idx.num_node_entries()); + dtypes.reserve(indexed_graph.num_node_entries()); for (const auto& i : arrays) dtypes.emplace_back(i->dtype()); CheckAndInferType(&graph, std::move(dtypes), false, node_range, entry_range); StorageTypeVector stypes; - stypes.reserve(idx.num_node_entries()); + stypes.reserve(indexed_graph.num_node_entries()); for (const auto& i : arrays) stypes.emplace_back(i->storage_type()); exec::DevMaskVector dev_mask; - dev_mask.reserve(idx.num_nodes()); + dev_mask.reserve(indexed_graph.num_nodes()); for (const auto& i : vctx) dev_mask.emplace_back(i.dev_mask()); CheckAndInferStorageType(&graph, std::move(dev_mask), std::move(stypes), false, node_range, entry_range); } // Calculate ref count - for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { - for (const auto& j : idx[i].inputs) { - ++ref_count[idx.entry_id(j)]; + for (size_t i = num_forward_nodes; i < indexed_graph.num_nodes(); ++i) { + for (const auto& j : indexed_graph[i].inputs) { + ++ref_count[indexed_graph.entry_id(j)]; } } // Assign reqs std::vector array_reqs(arrays.size(), kWriteTo); - for (size_t i = num_forward_entries; i < idx.num_node_entries(); ++i) { + for (size_t i = num_forward_entries; i < indexed_graph.num_node_entries(); ++i) { if (ref_count[i] == 0) array_reqs[i] = kNullOp; } - for (size_t i = num_forward_outputs; i < idx.outputs().size(); ++i) { - size_t eid = idx.entry_id(idx.outputs()[i]); - array_reqs[eid] = x_reqs[i - num_forward_outputs]; + for (size_t i = num_forward_outputs; i < indexed_graph.outputs().size(); ++i) { + size_t eid = indexed_graph.entry_id(indexed_graph.outputs()[i]); + array_reqs[eid] = gvars.op_req_types[i - num_forward_outputs]; } const auto& shapes = graph.GetAttr("shape"); @@ -489,10 +519,10 @@ std::vector Imperative::Backward( const auto& stypes = graph.GetAttr("storage_type"); const auto& dispatch_modes = graph.GetAttr("dispatch_mode"); - for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { - auto num_outputs = idx[i].source->num_outputs(); + for (size_t i = num_forward_nodes; i < indexed_graph.num_nodes(); ++i) { + auto num_outputs = indexed_graph[i].source->num_outputs(); for (size_t j = 0; j < num_outputs; ++j) { - auto eid = idx.entry_id(i, j); + auto eid = indexed_graph.entry_id(i, j); if (!arrays[eid]->is_none()) continue; if (stypes[eid] == kDefaultStorage) { *arrays[eid] = NDArray(shapes[eid], vctx[i], true, dtypes[eid]); @@ -514,7 +544,7 @@ std::vector Imperative::Backward( int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_); try { - RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), + RunGraph(retain_graph, indexed_graph, arrays, num_forward_nodes, indexed_graph.num_nodes(), std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, is_recording()); } catch (const dmlc::Error& e) { @@ -530,14 +560,14 @@ std::vector Imperative::Backward( // Clear history if (!retain_graph) { - nnvm::DFSVisit(sym.outputs, [&](const nnvm::NodePtr& n) { + nnvm::DFSVisit(forward_outputs, [&](const nnvm::NodePtr& n) { AGInfo::Clear(n); n->inputs.clear(); }); } if (variables.size()) { - return x_grads; + return gvars.gradients; } return {}; } diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 21caafa124f9..aeb8cfbf368d 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -763,7 +763,7 @@ inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { // forward pass for (size_t i = 0; i < idx.num_nodes(); ++i) { if (!idx[i].source->info.empty()) { - vctx[i] = dmlc::get(idx[i].source->info).ctx; + vctx[i] = dmlc::get(idx[i].source->info).ctx; } else if (idx[i].source->op() == _copyto) { CHECK_GT(idx[i].source->control_deps.size(), 0); auto fwd_nid = idx.node_id(idx[i].source->control_deps[0].get()); @@ -1011,7 +1011,7 @@ inline void CreateEngineOpSeg( if (stop && nid > seg_start) { auto& seg = (*opr_segs)[seg_start]; if (seg_execs.size()) { - seg = EngineOprSeg{false, nid}; + seg = EngineOprSeg{false, nid, nullptr}; seg.opr.reset(CreateEngineOp(default_ctx, seg_execs)); } else { seg = EngineOprSeg{true, nid, nullptr}; @@ -1028,7 +1028,7 @@ inline void CreateEngineOpSeg( seg_execs.clear(); seg_start = nid + 1; } else if (is_async) { - seg = EngineOprSeg{false, nid + 1}; + seg = EngineOprSeg{false, nid + 1, nullptr}; seg.opr.reset(CreateEngineOp(default_ctx, seg_execs)); seg_execs.clear(); seg_start = nid + 1; @@ -1038,7 +1038,7 @@ inline void CreateEngineOpSeg( if (end_nid > seg_start) { auto& seg = (*opr_segs)[seg_start]; if (seg_execs.size()) { - seg = EngineOprSeg{false, end_nid}; + seg = EngineOprSeg{false, end_nid, nullptr}; seg.opr.reset(CreateEngineOp(default_ctx, seg_execs)); } else { seg = EngineOprSeg{true, end_nid, nullptr}; diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 279690b594e6..a1f2f6c48044 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -126,7 +126,7 @@ class BatchLoader : public IIterator { } return false; } - virtual const TBlobBatch &Value(void) const { + virtual const TBlobBatch& Value(void) const { return out_; } diff --git a/src/io/iter_libsvm.cc b/src/io/iter_libsvm.cc index 3decc7b33e04..27bb546fe069 100644 --- a/src/io/iter_libsvm.cc +++ b/src/io/iter_libsvm.cc @@ -144,16 +144,16 @@ class LibSVMIter: public SparseIIterator { return true; } - virtual const DataInst &Value(void) const { + virtual const DataInst& Value(void) const { return out_; } - virtual const NDArrayStorageType GetStorageType(bool is_data) const { + virtual NDArrayStorageType GetStorageType(bool is_data) const { if (is_data) return kCSRStorage; return param_.label_shape.Size() > 1 ? kCSRStorage : kDefaultStorage; } - virtual const mxnet::TShape GetShape(bool is_data) const { + virtual mxnet::TShape GetShape(bool is_data) const { if (is_data) return param_.data_shape; return param_.label_shape; } diff --git a/src/io/iter_sparse.h b/src/io/iter_sparse.h index 22b1836be419..bb61990ef04f 100644 --- a/src/io/iter_sparse.h +++ b/src/io/iter_sparse.h @@ -36,9 +36,9 @@ template class SparseIIterator : public IIterator { public: /*! \brief storage type of the data or label */ - virtual const NDArrayStorageType GetStorageType(bool is_data) const = 0; + virtual NDArrayStorageType GetStorageType(bool is_data) const = 0; /*! \brief shape of the data or label */ - virtual const mxnet::TShape GetShape(bool is_data) const = 0; + virtual mxnet::TShape GetShape(bool is_data) const = 0; }; // class SparseIIterator } // namespace mxnet diff --git a/src/io/iter_sparse_batchloader.h b/src/io/iter_sparse_batchloader.h index c0d856df89ec..bd75b3ac1377 100644 --- a/src/io/iter_sparse_batchloader.h +++ b/src/io/iter_sparse_batchloader.h @@ -104,11 +104,11 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator return BatchLoader::Value(); } - virtual const NDArrayStorageType GetStorageType(bool is_data) const { + virtual NDArrayStorageType GetStorageType(bool is_data) const { return sparse_base_->GetStorageType(is_data); } - virtual const mxnet::TShape GetShape(bool is_data) const { + virtual mxnet::TShape GetShape(bool is_data) const { mxnet::TShape inst_shape = sparse_base_->GetShape(is_data); std::vector shape_vec; shape_vec.push_back(param_.batch_size); diff --git a/src/io/iter_sparse_prefetcher.h b/src/io/iter_sparse_prefetcher.h index 3f06052b0292..536f54fcddff 100644 --- a/src/io/iter_sparse_prefetcher.h +++ b/src/io/iter_sparse_prefetcher.h @@ -130,11 +130,11 @@ class SparsePrefetcherIter : public PrefetcherIter { return PrefetcherIter::Value(); } - virtual const NDArrayStorageType GetStorageType(bool is_data) const { + virtual NDArrayStorageType GetStorageType(bool is_data) const { return sparse_loader_->GetStorageType(is_data); } - virtual const mxnet::TShape GetShape(bool is_data) const { + virtual mxnet::TShape GetShape(bool is_data) const { return sparse_loader_->GetShape(is_data); } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 7fca6aa3f733..ba5a42c747c9 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -53,7 +53,7 @@ namespace mxnet { NDArray::NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Context ctx, bool delay_alloc, int dtype, std::vector aux_types, mxnet::ShapeVector aux_shapes, mxnet::TShape storage_shape) : shape_(shape), - dtype_(dtype), storage_type_(stype), entry_(nullptr) { + dtype_(dtype), storage_type_(stype), autograd_(nullptr) { // Assign default aux types if not given if (aux_types.size() == 0 && stype != kDefaultStorage) { @@ -158,8 +158,8 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) { } NDArray NDArray::grad() const { - if (Imperative::AGInfo::IsNone(*this)) return NDArray(); - Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node); + if (AGInfo::IsNone(*this)) return NDArray(); + AGInfo& info = AGInfo::Get(autograd_.node); if (info.out_grads.size()) { CHECK_EQ(info.out_grads.size(), 1); return info.out_grads[0]; @@ -168,17 +168,17 @@ NDArray NDArray::grad() const { } nnvm::Symbol NDArray::get_autograd_symbol() const { - CHECK(!Imperative::AGInfo::IsNone(*this)) + CHECK(!AGInfo::IsNone(*this)) << "NDArray is not part of a computation graph. Did you forget to turn on recording?"; nnvm::Symbol ret; - ret.outputs.emplace_back(entry_); + ret.outputs.emplace_back(autograd_); return ret; } #if MXNET_USE_MKLDNN == 1 NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd) - : storage_type_(kDefaultStorage), entry_(nullptr) { + : storage_type_(kDefaultStorage), autograd_(nullptr) { auto mem_desc = mem_pd.desc(); shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); dtype_ = get_mxnet_type(mem_desc.data.data_type); @@ -188,7 +188,7 @@ NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd) } NDArray::NDArray(const std::shared_ptr &mkldnn_mem) - : storage_type_(kDefaultStorage), entry_(nullptr) { + : storage_type_(kDefaultStorage), autograd_(nullptr) { auto mem_pd = mkldnn_mem->get_primitive_desc(); auto mem_desc = mem_pd.desc(); shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); @@ -378,16 +378,16 @@ NDArray NDArray::FromDLPack(const DLManagedTensor* tensor, bool transient_handle } bool NDArray::fresh_out_grad() const { - if (Imperative::AGInfo::IsNone(*this)) return false; - Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node); + if (AGInfo::IsNone(*this)) return false; + AGInfo& info = AGInfo::Get(autograd_.node); return info.fresh_out_grad; } void NDArray::set_fresh_out_grad(bool state) const { - CHECK(!Imperative::AGInfo::IsNone(*this)) + CHECK(!AGInfo::IsNone(*this)) << "NDArray has not been marked as a variable and does not have gradient state"; - Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node); + AGInfo& info = AGInfo::Get(autograd_.node); info.fresh_out_grad = state; } diff --git a/src/nnvm/gradient.cc b/src/nnvm/gradient.cc index 586027129a0b..2cde0321e373 100644 --- a/src/nnvm/gradient.cc +++ b/src/nnvm/gradient.cc @@ -190,7 +190,8 @@ Graph Gradient(Graph src) { if (grad_fun_map.contains(ptr->op())) { input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads); CHECK_EQ((*rit)->inputs.size(), input_grads.size()) - << "Gradient function not returning enough gradient"; + << "Gradient function not returning enough gradient, there should be as many gradients" + "as inputs returned."; } else if (CheckGradAllZero(out_agg_grads, zero_ops)) { for (size_t i = 0; i < fwd_node->num_inputs(); ++i) { std::ostringstream os; @@ -249,14 +250,14 @@ Graph Gradient(Graph src) { NodePtr copy_node = Node::Create(); std::ostringstream os; os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy"; - kv->second.first++; + ++kv->second.first; copy_node->attrs.op = copy_op; copy_node->attrs.name = os.str(); copy_node->inputs.emplace_back(entry.sum); if (copy_node->attrs.op->attr_parser != nullptr) { copy_node->attrs.op->attr_parser(&(copy_node->attrs)); } - unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter)); + unique_grads.emplace(NodeEntry(std::move(copy_node)), std::make_pair(1, counter)); } } else { ret.outputs[counter] = entry.sum; diff --git a/tests/cpp/misc/libinfo_test.cc b/tests/cpp/misc/libinfo_test.cc index 57f8f8d764c3..c3e5191e3c21 100644 --- a/tests/cpp/misc/libinfo_test.cc +++ b/tests/cpp/misc/libinfo_test.cc @@ -30,4 +30,5 @@ using namespace mxnet::features; TEST(RuntimeTest, RuntimeTestAll) { EXPECT_EQ(EnumNames::names.size(), MAX_FEATURES); const auto& features = LibInfo::getInstance()->getFeatures(); + EXPECT_FALSE(features.empty()); } diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py index 61955f034a71..e402a99ae69a 100644 --- a/tests/python/unittest/test_autograd.py +++ b/tests/python/unittest/test_autograd.py @@ -413,6 +413,7 @@ def test_get_symbol(): y = x*x + 2*z - 1 assert len(get_symbol(y).list_arguments()) == 2 + @with_seed() def test_grad_with_stype(): def check_grad_with_stype(array_stype, grad_stype, expected_stype): @@ -432,6 +433,7 @@ def check_grad_with_stype(array_stype, grad_stype, expected_stype): # check the stype of the gradient when provided check_grad_with_stype(stype, grad_stype, grad_stype) + @with_seed() def test_sparse_dot_grad(): def check_sparse_dot_grad(rhs): @@ -455,6 +457,7 @@ def check_sparse_dot_grad(rhs): dns.attach_grad(stype='row_sparse') check_sparse_dot_grad(dns) + @with_seed() def test_gradient(): x = mx.nd.ones((1,))