From 11de041e075bf057d7d2924c99316692269cb8d4 Mon Sep 17 00:00:00 2001 From: piiswrong Date: Tue, 22 Aug 2017 18:41:41 +0000 Subject: [PATCH 1/4] refactor cudnn algo reg to no use string --- dmlc-core | 2 +- mshadow | 2 +- nnvm | 2 +- python/mxnet/gluon/parameter.py | 3 +- python/mxnet/metric.py | 8 +- src/io/inst_vector.h | 10 +-- src/io/iter_mnist.cc | 2 +- src/ndarray/autograd.cc | 51 +++++------ src/operator/contrib/fft-inl.h | 6 +- src/operator/contrib/fft.cc | 10 +-- src/operator/contrib/fft.cu | 11 +-- src/operator/contrib/ifft-inl.h | 7 +- src/operator/contrib/ifft.cc | 10 +-- src/operator/contrib/ifft.cu | 11 +-- src/operator/contrib/multi_proposal-inl.h | 105 ++-------------------- src/operator/contrib/multi_proposal.cu | 6 +- src/operator/contrib/proposal-inl.h | 105 ++-------------------- src/operator/contrib/proposal.cc | 6 +- src/operator/contrib/proposal.cu | 6 +- src/operator/convolution-inl.h | 41 +++++++++ src/operator/cudnn_algoreg-inl.h | 86 ++++++++++++------ src/operator/cudnn_algoreg.cc | 13 ++- src/operator/cudnn_convolution-inl.h | 16 ++-- src/operator/cudnn_deconvolution-inl.h | 17 ++-- src/operator/deconvolution-inl.h | 45 ++++++++++ 25 files changed, 261 insertions(+), 320 deletions(-) diff --git a/dmlc-core b/dmlc-core index 71bfbd3a9460..e880afeb932d 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 71bfbd3a946075cea66ca9e19bad86dd33c19b46 +Subproject commit e880afeb932d746e55eb92e8c6eb3ff1b3697c48 diff --git a/mshadow b/mshadow index 6d75df228978..380f825b84e2 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit 6d75df228978ca5f182dd707578ef704099ab5ee +Subproject commit 380f825b84e28216516377e71199a8e14f12352f diff --git a/nnvm b/nnvm index bcfbf903429d..e842c098decf 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit bcfbf903429d086f16b19b4d202788de06e45536 +Subproject commit e842c098decf9f5eb6bd84e307c58e50078596b7 diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index bef55d67e140..749051d15055 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -199,6 +199,7 @@ def _finish_deferred_init(self): def _init_impl(self, data, ctx): """Sets data and grad.""" self._data = OrderedDict() + self._ctx_list = list(ctx) for i in ctx: self._data[i] = data.copyto(i) self._init_grad() @@ -377,7 +378,7 @@ def list_ctx(self): if self._deferred_init: return self._deferred_init[1] raise RuntimeError("Parameter %s has not been initialized"%self.name) - return list(self._data.keys()) + return self._ctx_list def zero_grad(self): """Sets gradient buffer on all contexts to 0. No action is taken if diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 00cc2da61f3c..a33b00ae8ab3 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -390,13 +390,13 @@ def update(self, labels, preds): for label, pred_label in zip(labels, preds): if pred_label.shape != label.shape: pred_label = ndarray.argmax(pred_label, axis=self.axis) - pred_label = pred_label.asnumpy().astype('int32') - label = label.asnumpy().astype('int32') + label = label.astype('int32') + pred_label = pred_label.astype('int32').as_in_context(label.context) check_label_shapes(label, pred_label) - self.sum_metric += (pred_label.flat == label.flat).sum() - self.num_inst += len(pred_label.flat) + self.sum_metric += ndarray.sum(label == pred_label).asscalar() + self.num_inst += label.size @register diff --git a/src/io/inst_vector.h b/src/io/inst_vector.h index 6dc7bdfd730a..fa66cba933a0 100644 --- a/src/io/inst_vector.h +++ b/src/io/inst_vector.h @@ -30,7 +30,6 @@ #include #include #include -#include #include #include @@ -173,16 +172,16 @@ struct TBlobBatch { } }; // struct TBlobBatch -class TBlobContainer : public mshadow::TBlob { +class TBlobContainer : public TBlob { public: TBlobContainer(void) - : mshadow::TBlob(), tensor_container_(nullptr) {} + : TBlob(), tensor_container_(nullptr) {} ~TBlobContainer() { if (tensor_container_) { release(); } } - void resize(const mshadow::TShape &shape, int type_flag) { + void resize(const TShape &shape, int type_flag) { if (tensor_container_) { CHECK_EQ(this->type_flag_, type_flag); this->shape_ = shape; @@ -192,13 +191,12 @@ class TBlobContainer : public mshadow::TBlob { this->shape_ = shape; create(); } - this->stride_ = shape_[shape_.ndim() - 1]; } private: void create() { CHECK(tensor_container_ == nullptr); - CHECK_EQ(this->dev_mask_, mshadow::cpu::kDevMask); + CHECK_EQ(this->dev_mask(), mshadow::cpu::kDevMask); MSHADOW_TYPE_SWITCH(this->type_flag_, DType, { auto tensor_container = new mshadow::TensorContainer(false); tensor_container->Resize(mshadow::Shape1(shape_.Size())); diff --git a/src/io/iter_mnist.cc b/src/io/iter_mnist.cc index 055af52aaebd..9dbedbbba448 100644 --- a/src/io/iter_mnist.cc +++ b/src/io/iter_mnist.cc @@ -103,7 +103,7 @@ class MNISTIter: public IIterator { out_.batch_size = param_.batch_size; if (param_.shuffle) this->Shuffle(); if (param_.silent == 0) { - mshadow::TShape s; + TShape s; s = batch_data_.shape_; if (param_.flat) { LOG(INFO) << "MNISTIter: load " << (unsigned)img_.size(0) << " images, shuffle=" diff --git a/src/ndarray/autograd.cc b/src/ndarray/autograd.cc index 5ecea5decf03..3fbb3a4924f2 100644 --- a/src/ndarray/autograd.cc +++ b/src/ndarray/autograd.cc @@ -127,10 +127,10 @@ AutogradRuntime* AutogradRuntime::Get() { } void AutogradRuntime::RecordOp(const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - std::vector *p_inputs, - std::vector *p_outputs, - const OpStatePtr& state) { + const nnvm::NodeAttrs& attrs, + std::vector *p_inputs, + std::vector *p_outputs, + const OpStatePtr& state) { static auto& fgradient = nnvm::Op::GetAttr("FGradient"); std::vector& inputs = *p_inputs; std::vector& outputs = *p_outputs; @@ -144,7 +144,6 @@ void AutogradRuntime::RecordOp(const nnvm::Op* op, << "Please call backward first to clear the graph or do this out side of " << "a record section. "; } - if (!fgradient.count(attrs.op)) return; bool need_grad = false; for (const auto& i : inputs) { if (!i.entry_.is_none()) { @@ -163,28 +162,30 @@ void AutogradRuntime::RecordOp(const nnvm::Op* op, for (uint32_t i = 0; i < inputs.size(); ++i) { nn_node->inputs.emplace_back(NodeEntry{nullptr, i, 0}); } - std::vector ograd_entries; - for (uint32_t i = 0; i < outputs.size(); ++i) { - ograd_entries.emplace_back(NodeEntry{nullptr, i, 1}); - } - auto igrad_entries = fgradient[nn_node->op()](nn_node, ograd_entries); - for (const auto& i : igrad_entries) { - if (i.node == nullptr && i.version == 0) { - save_inputs[i.index] = true; - } else if (i.node == nn_node) { - save_outputs[i.index] = true; + if (fgradient.count(attrs.op)) { + std::vector ograd_entries; + for (uint32_t i = 0; i < outputs.size(); ++i) { + ograd_entries.emplace_back(NodeEntry{nullptr, i, 1}); } - } - DFSVisit(igrad_entries, [&](const NodePtr& node) { - if (!node || node == nn_node) return; - for (const auto& i : node->inputs) { - if (i.node == nullptr && i.version == 0) { - save_inputs[i.index] = true; - } else if (i.node == nn_node) { - save_outputs[i.index] = true; - } + auto igrad_entries = fgradient[nn_node->op()](nn_node, ograd_entries); + for (const auto& i : igrad_entries) { + if (i.node == nullptr && i.version == 0) { + save_inputs[i.index] = true; + } else if (i.node == nn_node) { + save_outputs[i.index] = true; } - }); + } + DFSVisit(igrad_entries, [&](const NodePtr& node) { + if (!node || node == nn_node) return; + for (const auto& i : node->inputs) { + if (i.node == nullptr && i.version == 0) { + save_inputs[i.index] = true; + } else if (i.node == nn_node) { + save_outputs[i.index] = true; + } + } + }); + } AGNodePtr ag_node = AGNode::Create(nn_node); ag_node->state = state; diff --git a/src/operator/contrib/fft-inl.h b/src/operator/contrib/fft-inl.h index 5092f586fdf7..12474f183e84 100644 --- a/src/operator/contrib/fft-inl.h +++ b/src/operator/contrib/fft-inl.h @@ -54,6 +54,7 @@ struct FFTParam : public dmlc::Parameter { } }; +#if MXNET_USE_CUDA template class FFTOp : public Operator { public: @@ -102,7 +103,6 @@ class FFTOp : public Operator { Shape1(param_.compute_size*dim_*2), s); Tensor complex_data = Tensor(workspace.dptr_, Shape2(param_.compute_size, dim_*2), s); - #if MSHADOW_USE_CUDNN // start fft cufftHandle plan; cufftPlanMany(&plan, 1, &dim_, nullptr, 0, 0, nullptr, 0, 0, CUFFT_C2C, param_.compute_size); @@ -135,7 +135,6 @@ class FFTOp : public Operator { CHECK_EQ(cufftExecC2C(plan_remain, in_tmp, out_tmp, CUFFT_FORWARD), CUFFT_SUCCESS); cufftDestroy(plan_remain); } - #endif } virtual void Backward(const OpContext &ctx, @@ -170,7 +169,6 @@ class FFTOp : public Operator { // In this solution, out_grad must comes from a fft of real signal, // so that it is Hermitian symmetric, giving a real output // but if it is not, remember that we have implemented complex_take_real, and use this - #if MSHADOW_USE_CUDNN cufftHandle plan; cufftPlanMany(&plan, 1, &dim_, nullptr, 0, 0, nullptr, 0, 0, CUFFT_C2C, param_.compute_size); for (size_t idx = 0; idx < num_compute; ++idx) { @@ -203,7 +201,6 @@ class FFTOp : public Operator { req[fft::kData], complex_toreal(complex_data)); cufftDestroy(plan_remain); } - #endif // for bp, we should not divide it // but for comparison with np.fft.ifft, we should do it. // gdata /= dim_; @@ -214,6 +211,7 @@ class FFTOp : public Operator { int dim_, stride_, num_compute, n_ffts; bool init_cufft_; }; // class FFTOp +#endif // MXNET_USE_CUDA // Declare Factory Function, used for dispatch specialization template diff --git a/src/operator/contrib/fft.cc b/src/operator/contrib/fft.cc index 11f8425e07b1..6f78003baebb 100644 --- a/src/operator/contrib/fft.cc +++ b/src/operator/contrib/fft.cc @@ -28,17 +28,13 @@ namespace mxnet { namespace op { template<> Operator *CreateOp(FFTParam param, int dtype) { - LOG(FATAL) << "fft is only available for GPU."; - return NULL; + LOG(FATAL) << "fft is only available for GPU."; + return NULL; } Operator *FFTProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } DMLC_REGISTER_PARAMETER(FFTParam); diff --git a/src/operator/contrib/fft.cu b/src/operator/contrib/fft.cu index 3017ce76756b..dfe3fbba6124 100644 --- a/src/operator/contrib/fft.cu +++ b/src/operator/contrib/fft.cu @@ -29,11 +29,12 @@ namespace op { template<> Operator* CreateOp(FFTParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new FFTOp(param); - }) - return op; + Operator *op = NULL; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + op = new FFTOp(param); + }) + return op; } + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/ifft-inl.h b/src/operator/contrib/ifft-inl.h index abd5bb22a389..5e89c5b644ce 100644 --- a/src/operator/contrib/ifft-inl.h +++ b/src/operator/contrib/ifft-inl.h @@ -54,6 +54,7 @@ struct IFFTParam : public dmlc::Parameter { } }; +#if MXNET_USE_CUDA template class IFFTOp : public Operator { public: @@ -98,7 +99,6 @@ class IFFTOp : public Operator { Shape1(param_.compute_size*dim_*2), s); Tensor complex_data = Tensor(workspace.dptr_, Shape2(param_.compute_size, dim_*2), s); - #if MSHADOW_USE_CUDNN // start ifft cufftHandle plan; cufftPlanMany(&plan, 1, &dim_, nullptr, 0, 0, nullptr, 0, 0, CUFFT_C2C, param_.compute_size); @@ -131,7 +131,6 @@ class IFFTOp : public Operator { req[ifft::kOut], complex_toreal(complex_data)); cufftDestroy(plan_remain); } - #endif // commenting this out to be consistant with caffe // out /= dim_; } @@ -162,7 +161,6 @@ class IFFTOp : public Operator { Shape1(param_.compute_size*dim_*2), s); Tensor complex_data = Tensor(workspace.dptr_, Shape2(param_.compute_size, dim_*2), s); - #if MSHADOW_USE_CUDNN // start fft cufftHandle plan; cufftPlanMany(&plan, 1, &dim_, nullptr, 0, 0, nullptr, 0, 0, CUFFT_C2C, param_.compute_size); @@ -194,7 +192,6 @@ class IFFTOp : public Operator { CHECK_EQ(cufftExecC2C(plan_remain, in_tmp, out_tmp, CUFFT_FORWARD), CUFFT_SUCCESS); cufftDestroy(plan_remain); } - #endif // commenting this out to be consistant with caffe // gdata /= dim_; } @@ -205,6 +202,8 @@ class IFFTOp : public Operator { bool init_cufft_; }; // class IFFTOp +#endif // MXNET_USE_CUDA + // Declare Factory Function, used for dispatch specialization template Operator* CreateOp(IFFTParam param, int dtype); diff --git a/src/operator/contrib/ifft.cc b/src/operator/contrib/ifft.cc index 0ea3a7ec112f..95c79a785a16 100644 --- a/src/operator/contrib/ifft.cc +++ b/src/operator/contrib/ifft.cc @@ -29,17 +29,13 @@ namespace op { template<> Operator *CreateOp(IFFTParam param, int dtype) { - LOG(FATAL) << "ifft is only available for GPU."; - return NULL; + LOG(FATAL) << "ifft is only available for GPU."; + return NULL; } Operator *IFFTProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } DMLC_REGISTER_PARAMETER(IFFTParam); diff --git a/src/operator/contrib/ifft.cu b/src/operator/contrib/ifft.cu index 79795d8561bf..35cdb4836b37 100644 --- a/src/operator/contrib/ifft.cu +++ b/src/operator/contrib/ifft.cu @@ -29,11 +29,12 @@ namespace op { template<> Operator* CreateOp(IFFTParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new IFFTOp(param); - }) - return op; + Operator *op = NULL; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + op = new IFFTOp(param); + }) + return op; } + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/multi_proposal-inl.h b/src/operator/contrib/multi_proposal-inl.h index 7cd465e0b09e..ddfe0628f306 100644 --- a/src/operator/contrib/multi_proposal-inl.h +++ b/src/operator/contrib/multi_proposal-inl.h @@ -40,95 +40,6 @@ #include "../operator_common.h" #include "../mshadow_op.h" -// extend NumericalParam -namespace mxnet { -namespace op { - -/*! -* \brief structure for numerical tuple input -* \tparam VType data type of param -*/ -template -struct NumericalParam { - NumericalParam() {} - explicit NumericalParam(VType *begin, VType *end) { - int32_t size = static_cast(end - begin); - info.resize(size); - for (int i = 0; i < size; ++i) { - info[i] = *(begin + i); - } - } - inline size_t ndim() const { - return info.size(); - } - std::vector info; -}; - -template -inline std::istream &operator>>(std::istream &is, NumericalParam ¶m) { - while (true) { - char ch = is.get(); - if (ch == '(') break; - if (!isspace(ch)) { - is.setstate(std::ios::failbit); - return is; - } - } - VType idx; - std::vector tmp; - // deal with empty case - size_t pos = is.tellg(); - char ch = is.get(); - if (ch == ')') { - param.info = tmp; - return is; - } - is.seekg(pos); - // finish deal - while (is >> idx) { - tmp.push_back(idx); - char ch; - do { - ch = is.get(); - } while (isspace(ch)); - if (ch == ',') { - while (true) { - ch = is.peek(); - if (isspace(ch)) { - is.get(); continue; - } - if (ch == ')') { - is.get(); break; - } - break; - } - if (ch == ')') break; - } else if (ch == ')') { - break; - } else { - is.setstate(std::ios::failbit); - return is; - } - } - param.info = tmp; - return is; -} - -template -inline std::ostream &operator<<(std::ostream &os, const NumericalParam ¶m) { - os << '('; - for (index_t i = 0; i < param.info.size(); ++i) { - if (i != 0) os << ','; - os << param.info[i]; - } - // python style tuple - if (param.info.size() == 1) os << ','; - os << ')'; - return os; -} - -} // namespace op -} // namespace mxnet namespace mxnet { namespace op { @@ -144,8 +55,8 @@ struct MultiProposalParam : public dmlc::Parameter { int rpn_post_nms_top_n; float threshold; int rpn_min_size; - NumericalParam scales; - NumericalParam ratios; + nnvm::Tuple scales; + nnvm::Tuple ratios; int feature_stride; bool output_score; bool iou_loss; @@ -161,10 +72,10 @@ struct MultiProposalParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(rpn_min_size).set_default(16) .describe("Minimum height or width in proposal"); tmp[0] = 4.0f; tmp[1] = 8.0f; tmp[2] = 16.0f; tmp[3] = 32.0f; - DMLC_DECLARE_FIELD(scales).set_default(NumericalParam(tmp, tmp + 4)) + DMLC_DECLARE_FIELD(scales).set_default(nnvm::Tuple(tmp, tmp + 4)) .describe("Used to generate anchor windows by enumerating scales"); tmp[0] = 0.5f; tmp[1] = 1.0f; tmp[2] = 2.0f; - DMLC_DECLARE_FIELD(ratios).set_default(NumericalParam(tmp, tmp + 3)) + DMLC_DECLARE_FIELD(ratios).set_default(nnvm::Tuple(tmp, tmp + 3)) .describe("Used to generate anchor windows by enumerating ratios"); DMLC_DECLARE_FIELD(feature_stride).set_default(16) .describe("The size of the receptive field each unit in the convolution layer of the rpn," @@ -302,11 +213,11 @@ inline void _Transform(float scale, // out_anchors must have shape (n, 5), where n is ratios.size() * scales.size() inline void GenerateAnchors(const std::vector& base_anchor, - const std::vector& ratios, - const std::vector& scales, + const nnvm::Tuple& ratios, + const nnvm::Tuple& scales, std::vector *out_anchors) { - for (size_t j = 0; j < ratios.size(); ++j) { - for (size_t k = 0; k < scales.size(); ++k) { + for (size_t j = 0; j < ratios.ndim(); ++j) { + for (size_t k = 0; k < scales.ndim(); ++k) { _Transform(scales[k], ratios[j], base_anchor, out_anchors); } } diff --git a/src/operator/contrib/multi_proposal.cu b/src/operator/contrib/multi_proposal.cu index cb9996344e3e..082de6a397a7 100644 --- a/src/operator/contrib/multi_proposal.cu +++ b/src/operator/contrib/multi_proposal.cu @@ -460,11 +460,11 @@ class MultiProposalGPUOp : public Operator{ base_anchor[1] = 0.0; base_anchor[2] = param_.feature_stride - 1.0; base_anchor[3] = param_.feature_stride - 1.0; - CHECK_EQ(num_anchors, param_.ratios.info.size() * param_.scales.info.size()); + CHECK_EQ(num_anchors, param_.ratios.ndim() * param_.scales.ndim()); std::vector anchors; utils::GenerateAnchors(base_anchor, - param_.ratios.info, - param_.scales.info, + param_.ratios, + param_.scales, &anchors); // Copy generated anchors to GPU diff --git a/src/operator/contrib/proposal-inl.h b/src/operator/contrib/proposal-inl.h index 3d1851cedbac..f989cdec3767 100644 --- a/src/operator/contrib/proposal-inl.h +++ b/src/operator/contrib/proposal-inl.h @@ -38,95 +38,6 @@ #include "../operator_common.h" #include "../mshadow_op.h" -// extend NumericalParam -namespace mxnet { -namespace op { - -/*! -* \brief structure for numerical tuple input -* \tparam VType data type of param -*/ -template -struct NumericalParam { - NumericalParam() {} - explicit NumericalParam(VType *begin, VType *end) { - int32_t size = static_cast(end - begin); - info.resize(size); - for (int i = 0; i < size; ++i) { - info[i] = *(begin + i); - } - } - inline size_t ndim() const { - return info.size(); - } - std::vector info; -}; - -template -inline std::istream &operator>>(std::istream &is, NumericalParam ¶m) { - while (true) { - char ch = is.get(); - if (ch == '(') break; - if (!isspace(ch)) { - is.setstate(std::ios::failbit); - return is; - } - } - VType idx; - std::vector tmp; - // deal with empty case - size_t pos = is.tellg(); - char ch = is.get(); - if (ch == ')') { - param.info = tmp; - return is; - } - is.seekg(pos); - // finish deal - while (is >> idx) { - tmp.push_back(idx); - char ch; - do { - ch = is.get(); - } while (isspace(ch)); - if (ch == ',') { - while (true) { - ch = is.peek(); - if (isspace(ch)) { - is.get(); continue; - } - if (ch == ')') { - is.get(); break; - } - break; - } - if (ch == ')') break; - } else if (ch == ')') { - break; - } else { - is.setstate(std::ios::failbit); - return is; - } - } - param.info = tmp; - return is; -} - -template -inline std::ostream &operator<<(std::ostream &os, const NumericalParam ¶m) { - os << '('; - for (index_t i = 0; i < param.info.size(); ++i) { - if (i != 0) os << ','; - os << param.info[i]; - } - // python style tuple - if (param.info.size() == 1) os << ','; - os << ')'; - return os; -} - -} // namespace op -} // namespace mxnet namespace mxnet { namespace op { @@ -142,8 +53,8 @@ struct ProposalParam : public dmlc::Parameter { int rpn_post_nms_top_n; float threshold; int rpn_min_size; - NumericalParam scales; - NumericalParam ratios; + nnvm::Tuple scales; + nnvm::Tuple ratios; int feature_stride; bool output_score; bool iou_loss; @@ -159,10 +70,10 @@ struct ProposalParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(rpn_min_size).set_default(16) .describe("Minimum height or width in proposal"); tmp[0] = 4.0f; tmp[1] = 8.0f; tmp[2] = 16.0f; tmp[3] = 32.0f; - DMLC_DECLARE_FIELD(scales).set_default(NumericalParam(tmp, tmp + 4)) + DMLC_DECLARE_FIELD(scales).set_default(nnvm::Tuple(tmp, tmp + 4)) .describe("Used to generate anchor windows by enumerating scales"); tmp[0] = 0.5f; tmp[1] = 1.0f; tmp[2] = 2.0f; - DMLC_DECLARE_FIELD(ratios).set_default(NumericalParam(tmp, tmp + 3)) + DMLC_DECLARE_FIELD(ratios).set_default(nnvm::Tuple(tmp, tmp + 3)) .describe("Used to generate anchor windows by enumerating ratios"); DMLC_DECLARE_FIELD(feature_stride).set_default(16) .describe("The size of the receptive field each unit in the convolution layer of the rpn," @@ -300,11 +211,11 @@ inline void _Transform(float scale, // out_anchors must have shape (n, 5), where n is ratios.size() * scales.size() inline void GenerateAnchors(const std::vector& base_anchor, - const std::vector& ratios, - const std::vector& scales, + const nnvm::Tuple& ratios, + const nnvm::Tuple& scales, std::vector *out_anchors) { - for (size_t j = 0; j < ratios.size(); ++j) { - for (size_t k = 0; k < scales.size(); ++k) { + for (size_t j = 0; j < ratios.ndim(); ++j) { + for (size_t k = 0; k < scales.ndim(); ++k) { _Transform(scales[k], ratios[j], base_anchor, out_anchors); } } diff --git a/src/operator/contrib/proposal.cc b/src/operator/contrib/proposal.cc index ec539003b944..ccb541a403a2 100644 --- a/src/operator/contrib/proposal.cc +++ b/src/operator/contrib/proposal.cc @@ -335,11 +335,11 @@ class ProposalOp : public Operator{ base_anchor[1] = 0.0; base_anchor[2] = param_.feature_stride - 1.0; base_anchor[3] = param_.feature_stride - 1.0; - CHECK_EQ(num_anchors, param_.ratios.info.size() * param_.scales.info.size()); + CHECK_EQ(num_anchors, param_.ratios.ndim() * param_.scales.ndim()); std::vector anchors; utils::GenerateAnchors(base_anchor, - param_.ratios.info, - param_.scales.info, + param_.ratios, + param_.scales, &anchors); std::memcpy(workspace_proposals.dptr_, &anchors[0], sizeof(float) * anchors.size()); diff --git a/src/operator/contrib/proposal.cu b/src/operator/contrib/proposal.cu index 209ef79a2aaf..9f56685a7a7d 100644 --- a/src/operator/contrib/proposal.cu +++ b/src/operator/contrib/proposal.cu @@ -442,11 +442,11 @@ class ProposalGPUOp : public Operator{ base_anchor[1] = 0.0; base_anchor[2] = param_.feature_stride - 1.0; base_anchor[3] = param_.feature_stride - 1.0; - CHECK_EQ(num_anchors, param_.ratios.info.size() * param_.scales.info.size()); + CHECK_EQ(num_anchors, param_.ratios.ndim() * param_.scales.ndim()); std::vector anchors; utils::GenerateAnchors(base_anchor, - param_.ratios.info, - param_.scales.info, + param_.ratios, + param_.scales, &anchors); // Copy generated anchors to GPU diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h index 0edaee1dae32..a9e2c1bd6e94 100644 --- a/src/operator/convolution-inl.h +++ b/src/operator/convolution-inl.h @@ -103,8 +103,49 @@ struct ConvolutionParam : public dmlc::Parameter { index_t DilatedKernelSize(int dim) const { return 1 + (kernel[dim] - 1) * dilate[dim]; } + + bool operator==(const ConvolutionParam& other) const { + return this->kernel == other.kernel && + this->stride == other.stride && + this->dilate == other.dilate && + this->pad == other.pad && + this->num_filter == other.num_filter && + this->num_group == other.num_group && + this->workspace == other.workspace && + this->no_bias == other.no_bias && + this->cudnn_tune == other.cudnn_tune && + this->cudnn_off == other.cudnn_off && + this->layout == other.layout; + } }; +} // namespace op +} // namespace mxnet + +namespace std { +template<> +struct hash { + size_t operator()(const mxnet::op::ConvolutionParam& val) { + size_t ret = 0; + ret = dmlc::HashCombine(ret, val.kernel); + ret = dmlc::HashCombine(ret, val.stride); + ret = dmlc::HashCombine(ret, val.dilate); + ret = dmlc::HashCombine(ret, val.pad); + ret = dmlc::HashCombine(ret, val.num_filter); + ret = dmlc::HashCombine(ret, val.num_group); + ret = dmlc::HashCombine(ret, val.workspace); + ret = dmlc::HashCombine(ret, val.no_bias); + ret = dmlc::HashCombine(ret, val.cudnn_tune); + ret = dmlc::HashCombine(ret, val.cudnn_off); + ret = dmlc::HashCombine(ret, val.layout); + return ret; + } +}; +} // namespace std + +namespace mxnet { +namespace op { + template class ConvolutionOp : public Operator { public: diff --git a/src/operator/cudnn_algoreg-inl.h b/src/operator/cudnn_algoreg-inl.h index dc5db6bbc8b7..219faff288f1 100644 --- a/src/operator/cudnn_algoreg-inl.h +++ b/src/operator/cudnn_algoreg-inl.h @@ -61,37 +61,21 @@ class CuDNNAlgo { bool is_tensor_core_algo_; }; +template class CuDNNAlgoReg { public: - template - std::string GetKey(const Param ¶m, const std::vector &in_shape, - const std::vector &out_shape, - cudnnDataType_t cudnn_data_type, - cudnnDataType_t cudnn_forward_compute_type, - cudnnDataType_t cudnn_backward_compute_type, - int sm_arch) { - std::ostringstream oss; - oss << "inputs="; - for (auto &i : in_shape) - oss << i << ";"; - oss << "outputs="; - for (auto &i : out_shape) - oss << i << ";"; - auto dict = param.__DICT__(); - for (auto &k : dict) - oss << k.first << "=" << k.second << ";"; - oss << "cudnn_data_type=" << cudnn_data_type << ";"; - oss << "cudnn_forward_compute_type=" << cudnn_forward_compute_type << ";"; - oss << "cudnn_backward_compute_type=" << cudnn_backward_compute_type << ";"; - // All GPUs of the same compute capability (SM arch) share an algo selection. - oss << "sm_arch=" << sm_arch << ";"; - return oss.str(); - } - - bool Find(std::string key, + bool Find(const ParamType ¶m, + const std::vector &in_shape, + const std::vector &out_shape, + cudnnDataType_t cudnn_data_type, + cudnnDataType_t cudnn_forward_compute_type, + cudnnDataType_t cudnn_backward_compute_type, + int sm_arch, CuDNNAlgo *fwd, CuDNNAlgo *bwd, CuDNNAlgo *flt) { + ParamKey key{param, in_shape, out_shape, cudnn_data_type, cudnn_forward_compute_type, + cudnn_backward_compute_type, sm_arch}; std::lock_guard guard(lock_); auto i = reg_.find(key); if (i != reg_.end()) { @@ -103,10 +87,18 @@ class CuDNNAlgoReg { return false; } - void Register(std::string key, + void Register(const ParamType ¶m, + const std::vector &in_shape, + const std::vector &out_shape, + cudnnDataType_t cudnn_data_type, + cudnnDataType_t cudnn_forward_compute_type, + cudnnDataType_t cudnn_backward_compute_type, + int sm_arch, const CuDNNAlgo &fwd, const CuDNNAlgo &bwd, const CuDNNAlgo &flt) { + ParamKey key{param, in_shape, out_shape, cudnn_data_type, cudnn_forward_compute_type, + cudnn_backward_compute_type, sm_arch}; std::lock_guard guard(lock_); if (reg_.size() % 50 == 0) { LOG(INFO) << "Running performance tests to find the best convolution " @@ -134,9 +126,47 @@ class CuDNNAlgoReg { CuDNNAlgo flt; }; + struct ParamKey { + ParamType param; + std::vector in_shape; + std::vector out_shape; + cudnnDataType_t cudnn_data_type; + cudnnDataType_t cudnn_forward_compute_type; + cudnnDataType_t cudnn_backward_compute_type; + int sm_arch; + + bool operator==(const ParamKey& other) const { + return this->param == other.param && + this->in_shape == other.in_shape && + this->out_shape == other.out_shape && + this->cudnn_data_type == other.cudnn_data_type && + this->cudnn_forward_compute_type == other.cudnn_forward_compute_type && + this->cudnn_backward_compute_type == other.cudnn_backward_compute_type && + this->sm_arch == other.sm_arch; + } + }; + + struct ParamHash { + size_t operator()(const ParamKey& key) const { + std::hash hash_param; + size_t ret = hash_param(key.param); + for (const auto& i : key.in_shape) ret = dmlc::HashCombine(ret, i); + for (const auto& i : key.out_shape) ret = dmlc::HashCombine(ret, i); + ret = dmlc::HashCombine(ret, static_cast(key.cudnn_data_type)); + ret = dmlc::HashCombine(ret, static_cast(key.cudnn_forward_compute_type)); + ret = dmlc::HashCombine(ret, static_cast(key.cudnn_backward_compute_type)); + ret = dmlc::HashCombine(ret, key.sm_arch); + return ret; + } + }; + std::mutex lock_; - std::unordered_map reg_; + std::unordered_map reg_; }; + +typedef CuDNNAlgoReg CuDNNConvAlgoReg; +typedef CuDNNAlgoReg CuDNNDeconvAlgoReg; + #endif // __CUDACC__ && CUDNN } // namespace op } // namespace mxnet diff --git a/src/operator/cudnn_algoreg.cc b/src/operator/cudnn_algoreg.cc index 5aa8688c8148..5b0e73f0b19d 100644 --- a/src/operator/cudnn_algoreg.cc +++ b/src/operator/cudnn_algoreg.cc @@ -32,9 +32,16 @@ namespace mxnet { namespace op { #if MXNET_USE_CUDNN == 1 -CuDNNAlgoReg *CuDNNAlgoReg::Get() { - static CuDNNAlgoReg *ptr = new CuDNNAlgoReg(); - return ptr; +template<> +CuDNNAlgoReg *CuDNNAlgoReg::Get() { + static CuDNNAlgoReg inst; + return &inst; +} + +template<> +CuDNNAlgoReg *CuDNNAlgoReg::Get() { + static CuDNNAlgoReg inst; + return &inst; } #endif // CUDNN } // namespace op diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h index 428278498337..b2b59944e895 100644 --- a/src/operator/cudnn_convolution-inl.h +++ b/src/operator/cudnn_convolution-inl.h @@ -580,11 +580,10 @@ class CuDNNConvolutionOp : public Operator { const std::vector& out_shape, cudnnDataType_t cudnn_forward_compute_type, cudnnDataType_t cudnn_backward_compute_type) { - std::string key = CuDNNAlgoReg::Get()->GetKey(param_, in_shape, out_shape, dtype_, - cudnn_forward_compute_type, - cudnn_backward_compute_type, - SMArch(ctx.dev_id)); - if (!CuDNNAlgoReg::Get()->Find(key, &forward_algo_, &back_algo_, &back_algo_w_)) { + if (!CuDNNConvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_, + cudnn_forward_compute_type, cudnn_backward_compute_type, + SMArch(ctx.dev_id), &forward_algo_, &back_algo_, + &back_algo_w_)) { // Not in algo registry, must determine via *Get*() or *Find*() Engine::VarHandle var = Engine::Get()->NewVariable(); Engine::Get()->PushSync([=](RunContext rctx) { @@ -772,8 +771,11 @@ class CuDNNConvolutionOp : public Operator { // convolution will match only if identically specified. // We're caching results of *Get* as well as *Find*, but these records // will be held distinctly because param_.cudnn_tune is part of the key. - CuDNNAlgoReg::Get()->Register(key, this->forward_algo_, this->back_algo_, - this->back_algo_w_); + CuDNNConvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + SMArch(ctx.dev_id), this->forward_algo_, + this->back_algo_, this->back_algo_w_); }, ctx, {}, {var}); Engine::Get()->WaitForVar(var); Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var); diff --git a/src/operator/cudnn_deconvolution-inl.h b/src/operator/cudnn_deconvolution-inl.h index de3e70c7d6a7..5e9b7c5704d0 100644 --- a/src/operator/cudnn_deconvolution-inl.h +++ b/src/operator/cudnn_deconvolution-inl.h @@ -598,11 +598,11 @@ class CuDNNDeconvolutionOp : public Operator { const std::vector& out_shape, cudnnDataType_t cudnn_forward_compute_type, cudnnDataType_t cudnn_backward_compute_type) { - std::string key = CuDNNAlgoReg::Get()->GetKey(param_, in_shape, out_shape, dtype_, - cudnn_forward_compute_type, - cudnn_backward_compute_type, - SMArch(ctx.dev_id)); - if (!CuDNNAlgoReg::Get()->Find(key, &forward_algo_, &back_algo_, &back_algo_w_)) { + if (!CuDNNDeconvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + SMArch(ctx.dev_id), &forward_algo_, + &back_algo_, &back_algo_w_)) { // Not in algo registry, must determine via *Get*() or *Find*() Engine::VarHandle var = Engine::Get()->NewVariable(); Engine::Get()->PushSync([=](RunContext rctx) { @@ -793,8 +793,11 @@ class CuDNNDeconvolutionOp : public Operator { // convolution will match only if identically specified. // We're caching results of *Get* as well as *Find*, but these records // will be held distinctly because param_.cudnn_tune is part of the key. - CuDNNAlgoReg::Get()->Register(key, this->forward_algo_, this->back_algo_, - this->back_algo_w_); + CuDNNDeconvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_, + cudnn_forward_compute_type, + cudnn_backward_compute_type, + SMArch(ctx.dev_id), this->forward_algo_, + this->back_algo_, this->back_algo_w_); }, ctx, {}, {var}); Engine::Get()->WaitForVar(var); Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var); diff --git a/src/operator/deconvolution-inl.h b/src/operator/deconvolution-inl.h index dd77c150c970..a968ce44a800 100644 --- a/src/operator/deconvolution-inl.h +++ b/src/operator/deconvolution-inl.h @@ -144,8 +144,53 @@ struct DeconvolutionParam : public dmlc::Parameter { index_t DilatedKernelSize(int dim) const { return 1 + (kernel[dim] - 1) * dilate[dim]; } + + bool operator==(const DeconvolutionParam& other) const { + return this->kernel == other.kernel && + this->stride == other.stride && + this->dilate == other.dilate && + this->pad == other.pad && + this->adj == other.adj && + this->target_shape == other.target_shape && + this->num_filter == other.num_filter && + this->num_group == other.num_group && + this->workspace == other.workspace && + this->no_bias == other.no_bias && + this->cudnn_tune == other.cudnn_tune && + this->cudnn_off == other.cudnn_off && + this->layout == other.layout; + } }; +} // namespace op +} // namespace mxnet + +namespace std { +template<> +struct hash { + size_t operator()(const mxnet::op::DeconvolutionParam& val) { + size_t ret = 0; + ret = dmlc::HashCombine(ret, val.kernel); + ret = dmlc::HashCombine(ret, val.stride); + ret = dmlc::HashCombine(ret, val.dilate); + ret = dmlc::HashCombine(ret, val.pad); + ret = dmlc::HashCombine(ret, val.adj); + ret = dmlc::HashCombine(ret, val.target_shape); + ret = dmlc::HashCombine(ret, val.num_filter); + ret = dmlc::HashCombine(ret, val.num_group); + ret = dmlc::HashCombine(ret, val.workspace); + ret = dmlc::HashCombine(ret, val.no_bias); + ret = dmlc::HashCombine(ret, val.cudnn_tune); + ret = dmlc::HashCombine(ret, val.cudnn_off); + ret = dmlc::HashCombine(ret, val.layout); + return ret; + } +}; +} // namespace std + +namespace mxnet { +namespace op { + template class DeconvolutionOp : public Operator { public: From 46f0814e27a9e6ce1217cbada5e0f1bf82c46282 Mon Sep 17 00:00:00 2001 From: piiswrong Date: Wed, 23 Aug 2017 18:26:01 +0000 Subject: [PATCH 2/4] refactor ctx list --- python/mxnet/gluon/parameter.py | 56 +++++++++++++++------------------ 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 749051d15055..4bc2611a70a7 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -129,14 +129,22 @@ def grad_req(self, req): elif self._data is not None: self._init_grad() - def _check_initialized(self, ctx=None): - if self._data is not None: - if ctx is not None and ctx not in self._data: - raise RuntimeError( - "Parameter %s was not initialized on context %s. " - "It was only initialized on %s."%( - self.name, str(ctx), str(self.list_ctx()))) - return + def _check_and_get(self, arr_dict, ctx): + if arr_dict is not None: + if ctx is list: + return list(arr_dict.values()) + if ctx is None: + if len(self._ctx_list) == 1: + ctx = self._ctx_list[0] + else: + ctx = context.current_context() + ret = arr_dict.get(ctx, None) + if ret is not None: + return ret + raise RuntimeError( + "Parameter %s was not initialized on context %s. " + "It was only initialized on %s."%( + self.name, str(ctx), str(self._ctx_list))) if self._deferred_init: raise DeferredInitializationError raise RuntimeError( @@ -328,20 +336,12 @@ def data(self, ctx=None): ------- NDArray on ctx """ - if ctx is None: - list_ctx = self.list_ctx() - if len(list_ctx) == 1: - ctx = list_ctx[0] - else: - ctx = context.current_context() - self._check_initialized(ctx) - return self._data[ctx] + return self._check_and_get(self._data, ctx) def list_data(self): """Returns copies of this parameter on all contexts, in the same order as creation.""" - self._check_initialized() - return list(self._data.values()) + return self._check_and_get(self._data, list) def grad(self, ctx=None): """Returns a gradient buffer for this parameter on one context. @@ -351,26 +351,20 @@ def grad(self, ctx=None): ctx : Context Desired context. """ - if ctx is None: - list_ctx = self.list_ctx() - if len(list_ctx) == 1: - ctx = list_ctx[0] - else: - ctx = context.current_context() - self._check_initialized(ctx) - if self._grad is None: + if self._data is not None and self._grad is None: raise RuntimeError( "Cannot get gradient array for Parameter %s " \ "because grad_req='null'"%(self.name)) - return self._grad[ctx] + return self._check_and_get(self._grad, ctx) def list_grad(self): """Returns gradient buffers on all contexts, in the same order as `values`.""" - self._check_initialized() - assert self._grad is not None, \ - "Parameter %s does not have gradients because grad_req='null'"%self.name - return list(self._grad.values()) + if self._data is not None and self._grad is None: + raise RuntimeError( + "Cannot get gradient array for Parameter %s " \ + "because grad_req='null'"%(self.name)) + return self._check_and_get(self._grad, list) def list_ctx(self): """Returns a list of contexts this parameter is initialized on.""" From 5c95e46123a66c8a18d3a98f009feb9fdfd1eb83 Mon Sep 17 00:00:00 2001 From: piiswrong Date: Wed, 23 Aug 2017 20:18:33 +0000 Subject: [PATCH 3/4] fix --- src/operator/cudnn_algoreg-inl.h | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/operator/cudnn_algoreg-inl.h b/src/operator/cudnn_algoreg-inl.h index 219faff288f1..b27d2be297fe 100644 --- a/src/operator/cudnn_algoreg-inl.h +++ b/src/operator/cudnn_algoreg-inl.h @@ -74,8 +74,9 @@ class CuDNNAlgoReg { CuDNNAlgo *fwd, CuDNNAlgo *bwd, CuDNNAlgo *flt) { - ParamKey key{param, in_shape, out_shape, cudnn_data_type, cudnn_forward_compute_type, - cudnn_backward_compute_type, sm_arch}; + CHECK(in_shape.size() == 2 || in_shape.size() == 3); + ParamKey key{param, in_shape[0], in_shape[1], out_shape[0], cudnn_data_type, + cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch}; std::lock_guard guard(lock_); auto i = reg_.find(key); if (i != reg_.end()) { @@ -97,8 +98,9 @@ class CuDNNAlgoReg { const CuDNNAlgo &fwd, const CuDNNAlgo &bwd, const CuDNNAlgo &flt) { - ParamKey key{param, in_shape, out_shape, cudnn_data_type, cudnn_forward_compute_type, - cudnn_backward_compute_type, sm_arch}; + CHECK(in_shape.size() == 2 || in_shape.size() == 3); + ParamKey key{param, in_shape[0], in_shape[1], out_shape[0], cudnn_data_type, + cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch}; std::lock_guard guard(lock_); if (reg_.size() % 50 == 0) { LOG(INFO) << "Running performance tests to find the best convolution " @@ -128,8 +130,7 @@ class CuDNNAlgoReg { struct ParamKey { ParamType param; - std::vector in_shape; - std::vector out_shape; + TShape data_shape, weight_shape, out_shape; cudnnDataType_t cudnn_data_type; cudnnDataType_t cudnn_forward_compute_type; cudnnDataType_t cudnn_backward_compute_type; @@ -137,7 +138,8 @@ class CuDNNAlgoReg { bool operator==(const ParamKey& other) const { return this->param == other.param && - this->in_shape == other.in_shape && + this->data_shape == other.data_shape && + this->weight_shape == other.weight_shape && this->out_shape == other.out_shape && this->cudnn_data_type == other.cudnn_data_type && this->cudnn_forward_compute_type == other.cudnn_forward_compute_type && @@ -150,7 +152,9 @@ class CuDNNAlgoReg { size_t operator()(const ParamKey& key) const { std::hash hash_param; size_t ret = hash_param(key.param); - for (const auto& i : key.in_shape) ret = dmlc::HashCombine(ret, i); + ret = dmlc::HashCombine(ret, key.data_shape); + ret = dmlc::HashCombine(ret, key.weight_shape); + ret = dmlc::HashCombine(ret, key.out_shape); for (const auto& i : key.out_shape) ret = dmlc::HashCombine(ret, i); ret = dmlc::HashCombine(ret, static_cast(key.cudnn_data_type)); ret = dmlc::HashCombine(ret, static_cast(key.cudnn_forward_compute_type)); From 23bae5ae2106e377f59625b79a1d65f0c1f355e1 Mon Sep 17 00:00:00 2001 From: piiswrong Date: Thu, 24 Aug 2017 00:45:05 +0000 Subject: [PATCH 4/4] refactor save_inputs --- src/c_api/c_api_common.h | 2 + src/c_api/c_api_function.cc | 4 +- src/c_api/c_api_ndarray.cc | 56 ++++++++++++------ src/ndarray/autograd.cc | 110 ++++++++++++++++++++---------------- src/ndarray/autograd.h | 30 +++++----- src/ndarray/ndarray.cc | 6 +- 6 files changed, 121 insertions(+), 87 deletions(-) diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index fee3f03f6db0..1ef385609239 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -84,6 +84,8 @@ struct MXAPIThreadLocalEntry { std::vector arg_shape_data, out_shape_data, aux_shape_data; /*! \brief uint32_t buffer for returning shape pointer */ std::vector arg_shape_buffer, out_shape_buffer, aux_shape_buffer; + /*! \brief bool buffer */ + std::vector save_inputs, save_outputs; // helper function to setup return value of shape array inline static void SetupShapeArrayReturnWithBuffer( const std::vector &shapes, diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc index 3d8b5328c1a0..259c1331c7af 100644 --- a/src/c_api/c_api_function.cc +++ b/src/c_api/c_api_function.cc @@ -188,8 +188,8 @@ int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs, attrs.parsed = params; // TODO(piiswrong): remove state by using FComputeEx auto state = OpStatePtr::Create(params); - AutogradRuntime::Get()->RecordImperativeOperator( - state, attrs.op, attrs, &ndinputs, &ndoutputs); + AutogradRuntime::Get()->RecordOp( + std::move(attrs), &ndinputs, &ndoutputs, state); for (size_t i = 0; i < ndoutputs.size(); ++i) { *reinterpret_cast(outputs[i]) = ndoutputs[i]; diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index d392baf45d3e..64fa74d8b8c3 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -484,9 +484,11 @@ void PushOperator(const OpStatePtr& state, } void ImperativeInvokeImpl(const Context& default_ctx, - const nnvm::NodeAttrs& attrs, + nnvm::NodeAttrs&& attrs, std::vector* p_ndinputs, - std::vector* p_ndoutputs) { + std::vector* p_ndoutputs, + std::vector* p_save_inputs = nullptr, + std::vector* p_save_outputs = nullptr) { static auto& ndfunc = nnvm::Op::GetAttr("FNDArrayFunction"); static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); @@ -514,29 +516,32 @@ void ImperativeInvokeImpl(const Context& default_ctx, FCompute fn = common::GetFCompute(op, "FCompute", ctx); FComputeEx fn_ex = common::GetFCompute(op, "FComputeEx", ctx); if (fn_ex && stype != kDefaultStorage) { - if (AutogradRuntime::Get()->IsRecording()) { - AutogradRuntime::Get()->RecordImperativeFCompute(op, - attrs, &ndinputs, &ndoutputs); - } PushFComputeEx(fn_ex, op, attrs, ctx, read_vars, write_vars, requested, ndinputs, ndoutputs); - } else if (fn) { if (AutogradRuntime::Get()->IsRecording()) { - AutogradRuntime::Get()->RecordImperativeFCompute(op, - attrs, &ndinputs, &ndoutputs); + AutogradRuntime::Get()->RecordOp( + std::move(attrs), &ndinputs, &ndoutputs, OpStatePtr(), + p_save_inputs, p_save_outputs); } + } else if (fn) { PushFCompute(fn, op, attrs, ctx, read_vars, write_vars, requested, ndinputs, ndoutputs, mutate_idx); + if (AutogradRuntime::Get()->IsRecording()) { + AutogradRuntime::Get()->RecordOp( + std::move(attrs), &ndinputs, &ndoutputs, OpStatePtr(), + p_save_inputs, p_save_outputs); + } } else if (createop.count(op)) { auto state = createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types); - if (AutogradRuntime::Get()->IsRecording()) { - AutogradRuntime::Get()->RecordImperativeOperator(state, op, - attrs, &ndinputs, &ndoutputs); - } write_vars.push_back(state.get_var()); PushOperator(state, op, attrs, ctx, read_vars, write_vars, requested, ndinputs, ndoutputs, mutate_idx); + if (AutogradRuntime::Get()->IsRecording()) { + AutogradRuntime::Get()->RecordOp( + std::move(attrs), &ndinputs, &ndoutputs, state, + p_save_inputs, p_save_outputs); + } } else { LOG(FATAL) << "Operator " << op->name << " is not implemented for " @@ -569,7 +574,7 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, SetNDInputsOutputs(op, &ndinputs, &ndoutputs, num_inputs, inputs, num_outputs, infered_num_outputs, num_visible_outputs, outarray); - ImperativeInvokeImpl(Context::CPU(), attrs, &ndinputs, &ndoutputs); + ImperativeInvokeImpl(Context::CPU(), std::move(attrs), &ndinputs, &ndoutputs); if (outarray == nullptr) { ret->ret_handles.clear(); @@ -618,6 +623,20 @@ int MXCreateCachedOp(SymbolHandle handle, auto vars = sym->ListInputs(nnvm::Symbol::kAll); CHECK_GE(vars.size(), 1) << "CachedOp must have at least 1 input."; g->attrs["vars"] = std::make_shared(std::move(vars)); + + const nnvm::IndexedGraph& idx = g->indexed_graph(); + std::vector > save_inputs(idx.num_nodes()); + std::vector > save_outputs(idx.num_nodes()); + for (size_t i = 0; i < idx.num_nodes(); ++i) { + nnvm::NodePtr node = nnvm::Node::Create(); + node->attrs = idx[i].source->attrs; + AutogradRuntime::Get()->GetBackwardDependency( + node, idx[i].source->num_inputs(), idx[i].source->num_outputs(), + &save_inputs[i], &save_outputs[i]); + } + g->attrs["save_inputs"] = std::make_shared(std::move(save_inputs)); + g->attrs["save_outputs"] = std::make_shared(std::move(save_outputs)); + *out = g; API_END(); } @@ -640,7 +659,11 @@ int MXInvokeCachedOp(CachedOpHandle handle, API_BEGIN(); const std::vector& vars = - g->GetAttr >("vars"); + g->GetAttr >("vars"); + std::vector > save_inputs = + g->GetAttr > >("save_inputs"); + std::vector > save_outputs = + g->GetAttr > >("save_outputs"); const nnvm::IndexedGraph& idx = g->indexed_graph(); CHECK_EQ(static_cast(num_inputs), vars.size()) << "Actually number of inputs differs from expected number of inputs"; @@ -661,7 +684,8 @@ int MXInvokeCachedOp(CachedOpHandle handle, in.emplace_back(buff[idx.entry_id(j)]); } std::vector out(node.source->num_outputs()); - ImperativeInvokeImpl(default_ctx, node.source->attrs, &in, &out); + ImperativeInvokeImpl(default_ctx, nnvm::NodeAttrs(node.source->attrs), &in, &out, + &save_inputs[i], &save_outputs[i]); for (size_t j = 0; j < node.source->num_outputs(); ++j) { buff[idx.entry_id(i, j)] = std::move(out[j]); diff --git a/src/ndarray/autograd.cc b/src/ndarray/autograd.cc index 3fbb3a4924f2..421113f6edd7 100644 --- a/src/ndarray/autograd.cc +++ b/src/ndarray/autograd.cc @@ -29,6 +29,7 @@ #include #include "../executor/graph_executor.h" #include "./autograd.h" +#include "../c_api/c_api_common.h" namespace mxnet { namespace autograd { @@ -101,21 +102,6 @@ void AutogradRuntime::MarkVariables( } } -void AutogradRuntime::RecordImperativeFCompute(const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - std::vector *p_inputs, - std::vector *p_outputs) { - RecordOp(op, attrs, p_inputs, p_outputs, OpStatePtr()); -} - -void AutogradRuntime::RecordImperativeOperator(const OpStatePtr& state, - const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - std::vector *p_inputs, - std::vector *p_outputs) { - RecordOp(op, attrs, p_inputs, p_outputs, state); -} - std::shared_ptr AutogradRuntime::_GetSharedRef() { static std::shared_ptr inst(new AutogradRuntime()); return inst; @@ -126,12 +112,58 @@ AutogradRuntime* AutogradRuntime::Get() { return ptr; } -void AutogradRuntime::RecordOp(const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, +void AutogradRuntime::GetBackwardDependency(const nnvm::NodePtr& node, + uint32_t num_inputs, uint32_t num_outputs, + std::vector *p_save_inputs, + std::vector *p_save_outputs) { + static auto& fgradient = nnvm::Op::GetAttr("FGradient"); + std::vector& save_inputs = *p_save_inputs; + std::vector& save_outputs = *p_save_outputs; + save_inputs.resize(num_inputs); + save_outputs.resize(num_outputs); + std::fill(save_inputs.begin(), save_inputs.end(), false); + std::fill(save_outputs.begin(), save_outputs.end(), false); + + node->inputs.clear(); + node->inputs.reserve(num_inputs); + for (uint32_t i = 0; i < num_inputs; ++i) { + node->inputs.emplace_back(NodeEntry{nullptr, i, 0}); + } + + if (fgradient.count(node->op())) { + std::vector ograd_entries; + ograd_entries.reserve(num_outputs); + for (uint32_t i = 0; i < num_outputs; ++i) { + ograd_entries.emplace_back(NodeEntry{nullptr, i, 1}); + } + auto igrad_entries = fgradient[node->op()](node, ograd_entries); + for (const auto& i : igrad_entries) { + if (i.node == nullptr && i.version == 0) { + save_inputs[i.index] = true; + } else if (i.node == node) { + save_outputs[i.index] = true; + } + } + DFSVisit(igrad_entries, [&](const NodePtr& gnode) { + if (!gnode || gnode == node) return; + for (const auto& i : gnode->inputs) { + if (i.node == nullptr && i.version == 0) { + save_inputs[i.index] = true; + } else if (i.node == node) { + save_outputs[i.index] = true; + } + } + }); + } +} + +void AutogradRuntime::RecordOp(nnvm::NodeAttrs&& attrs, std::vector *p_inputs, std::vector *p_outputs, - const OpStatePtr& state) { - static auto& fgradient = nnvm::Op::GetAttr("FGradient"); + const OpStatePtr& state, + std::vector* p_save_inputs, + std::vector* p_save_outputs) { + MXAPIThreadLocalEntry *local_buff = MXAPIThreadLocalStore::Get(); std::vector& inputs = *p_inputs; std::vector& outputs = *p_outputs; @@ -154,39 +186,21 @@ void AutogradRuntime::RecordOp(const nnvm::Op* op, if (!need_grad) return; NodePtr nn_node = Node::Create(); - nn_node->attrs = attrs; + nn_node->attrs = std::move(attrs); nn_node->attrs.name = "node_" + std::to_string(node_count_++); - // Get backward dependency - std::vector save_inputs(inputs.size()), save_outputs(outputs.size()); - for (uint32_t i = 0; i < inputs.size(); ++i) { - nn_node->inputs.emplace_back(NodeEntry{nullptr, i, 0}); - } - if (fgradient.count(attrs.op)) { - std::vector ograd_entries; - for (uint32_t i = 0; i < outputs.size(); ++i) { - ograd_entries.emplace_back(NodeEntry{nullptr, i, 1}); - } - auto igrad_entries = fgradient[nn_node->op()](nn_node, ograd_entries); - for (const auto& i : igrad_entries) { - if (i.node == nullptr && i.version == 0) { - save_inputs[i.index] = true; - } else if (i.node == nn_node) { - save_outputs[i.index] = true; - } - } - DFSVisit(igrad_entries, [&](const NodePtr& node) { - if (!node || node == nn_node) return; - for (const auto& i : node->inputs) { - if (i.node == nullptr && i.version == 0) { - save_inputs[i.index] = true; - } else if (i.node == nn_node) { - save_outputs[i.index] = true; - } - } - }); + if (p_save_inputs == nullptr) { + p_save_inputs = &(local_buff->save_inputs); + p_save_outputs = &(local_buff->save_outputs); + GetBackwardDependency( + nn_node, inputs.size(), outputs.size(), p_save_inputs, p_save_outputs); + } else { + nn_node->inputs.resize(inputs.size()); } + std::vector& save_inputs = *p_save_inputs; + std::vector& save_outputs = *p_save_outputs; + AGNodePtr ag_node = AGNode::Create(nn_node); ag_node->state = state; diff --git a/src/ndarray/autograd.h b/src/ndarray/autograd.h index 199af350bf93..4632bc00ebf5 100644 --- a/src/ndarray/autograd.h +++ b/src/ndarray/autograd.h @@ -95,17 +95,19 @@ class AutogradRuntime { void MarkVariables(const std::vector& variables, const std::vector& grad_reqs, const std::vector& gradients); - /*! \brief record imperative operator which is executed by fcompute. */ - void RecordImperativeFCompute(const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - std::vector* p_inputs, - std::vector* p_outputs); - /*! \brief record imperative operator which is executed by operator. */ - void RecordImperativeOperator(const OpStatePtr& state, - const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - std::vector* p_inputs, - std::vector* p_outputs); + /*! \brief find the input/output ndarrays that are needed for backward */ + void GetBackwardDependency( + const nnvm::NodePtr& node, + uint32_t num_inputs, uint32_t num_outputs, + std::vector *p_save_inputs, + std::vector *p_save_outputs); + /*! \brief to record operator, return corresponding node. */ + void RecordOp(nnvm::NodeAttrs&& attrs, + std::vector* p_inputs, + std::vector* p_outputs, + const OpStatePtr& state = OpStatePtr(), + std::vector* p_save_inputs = nullptr, + std::vector* p_save_outputs = nullptr); /*! \brief compute the gradient of outputs w.r.t variables. */ void ComputeGradient(const std::vector& outputs, const std::vector& ograds, @@ -126,12 +128,6 @@ class AutogradRuntime { AutogradRuntime(); private: - /*! \brief to record operator, return corresponding node. */ - void RecordOp(const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - std::vector* p_inputs, - std::vector* p_outputs, - const OpStatePtr& state); /*! \brief AutogradRuntime singleton. */ static AutogradRuntime* instance_; /*! \brief indicate whether is training. */ diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 139d97670bec..7b79d1051135 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -75,8 +75,7 @@ NDArray NDArray::Reshape(const TShape &shape) const { std::vector inputs, outputs; inputs.emplace_back(*this); outputs.emplace_back(std::move(ret)); - AutogradRuntime::Get()->RecordImperativeFCompute( - op, attrs, &inputs, &outputs); + AutogradRuntime::Get()->RecordOp(std::move(attrs), &inputs, &outputs); return outputs[0]; } else { CHECK_GE(shape_.Size(), shape.Size()) @@ -115,8 +114,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const { std::vector inputs, outputs; inputs.emplace_back(*this); outputs.emplace_back(std::move(ret)); - AutogradRuntime::Get()->RecordImperativeFCompute( - op, attrs, &inputs, &outputs); + AutogradRuntime::Get()->RecordOp(std::move(attrs), &inputs, &outputs); return outputs[0]; } else { return ret;