From 9a6a3ada27850013179e29b690fb7d8f8bad6937 Mon Sep 17 00:00:00 2001 From: Ming Chuan Date: Thu, 23 Mar 2017 23:07:57 +0800 Subject: [PATCH 1/2] [cpp-package] Fix multiple definition issue When mxnet hpp headers being included in different translation units, link error occurs because there are multiple (duplicate) definition of functions in different translation units. This commit inlines mxnet cpp functions to follow the one definition rule. --- cpp-package/include/mxnet-cpp/executor.hpp | 20 ++-- cpp-package/include/mxnet-cpp/io.h | 2 +- cpp-package/include/mxnet-cpp/io.hpp | 23 ++-- cpp-package/include/mxnet-cpp/kvstore.h | 3 + cpp-package/include/mxnet-cpp/kvstore.hpp | 93 ++++++++------- cpp-package/include/mxnet-cpp/ndarray.hpp | 118 ++++++++++---------- cpp-package/include/mxnet-cpp/operator.h | 2 +- cpp-package/include/mxnet-cpp/operator.hpp | 25 +++-- cpp-package/include/mxnet-cpp/optimizer.h | 4 +- cpp-package/include/mxnet-cpp/optimizer.hpp | 51 +++++---- cpp-package/include/mxnet-cpp/symbol.h | 2 +- cpp-package/include/mxnet-cpp/symbol.hpp | 75 +++++++------ 12 files changed, 216 insertions(+), 202 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/executor.hpp b/cpp-package/include/mxnet-cpp/executor.hpp index c642a96268dd..ebae734f88b8 100644 --- a/cpp-package/include/mxnet-cpp/executor.hpp +++ b/cpp-package/include/mxnet-cpp/executor.hpp @@ -16,13 +16,13 @@ namespace mxnet { namespace cpp { -Executor::Executor(const Symbol &symbol, Context context, - const std::vector &arg_arrays, - const std::vector &grad_arrays, - const std::vector &grad_reqs, - const std::vector &aux_arrays, - const std::map &group_to_ctx, - Executor *shared_exec) { +inline Executor::Executor(const Symbol &symbol, Context context, + const std::vector &arg_arrays, + const std::vector &grad_arrays, + const std::vector &grad_reqs, + const std::vector &aux_arrays, + const std::map &group_to_ctx, + Executor *shared_exec) { this->arg_arrays = arg_arrays; this->grad_arrays = grad_arrays; this->aux_arrays = aux_arrays; @@ -73,14 +73,14 @@ Executor::Executor(const Symbol &symbol, Context context, } } -std::string Executor::DebugStr() { +inline std::string Executor::DebugStr() { const char *output; MXExecutorPrint(handle_, &output); return std::string(output); } -void Executor::UpdateAll(Optimizer *opt, float lr, float wd, - int arg_update_begin, int arg_update_end) { +inline void Executor::UpdateAll(Optimizer *opt, float lr, float wd, + int arg_update_begin, int arg_update_end) { arg_update_end = arg_update_end < 0 ? arg_arrays.size() - 1 : arg_update_end; for (int i = arg_update_begin; i < arg_update_end; ++i) { opt->Update(i, arg_arrays[i], grad_arrays[i], lr, wd); diff --git a/cpp-package/include/mxnet-cpp/io.h b/cpp-package/include/mxnet-cpp/io.h index 41c02f249614..2e96d1e39e73 100644 --- a/cpp-package/include/mxnet-cpp/io.h +++ b/cpp-package/include/mxnet-cpp/io.h @@ -119,7 +119,7 @@ class MXDataIter : public DataIter { DataIterCreator creator_; std::map params_; std::shared_ptr blob_ptr_; - static MXDataIterMap *mxdataiter_map_; + static MXDataIterMap*& mxdataiter_map(); }; } // namespace cpp } // namespace mxnet diff --git a/cpp-package/include/mxnet-cpp/io.hpp b/cpp-package/include/mxnet-cpp/io.hpp index 853a6bafb488..c905ef9ced55 100644 --- a/cpp-package/include/mxnet-cpp/io.hpp +++ b/cpp-package/include/mxnet-cpp/io.hpp @@ -14,46 +14,49 @@ namespace mxnet { namespace cpp { -MXDataIterMap *MXDataIter::mxdataiter_map_ = new MXDataIterMap; +inline MXDataIterMap*& MXDataIter::mxdataiter_map() { + static MXDataIterMap* mxdataiter_map_ = new MXDataIterMap; + return mxdataiter_map_; +} -MXDataIter::MXDataIter(const std::string &mxdataiter_type) { - creator_ = mxdataiter_map_->GetMXDataIterCreator(mxdataiter_type); +inline MXDataIter::MXDataIter(const std::string &mxdataiter_type) { + creator_ = mxdataiter_map()->GetMXDataIterCreator(mxdataiter_type); blob_ptr_ = std::make_shared(nullptr); } -void MXDataIter::BeforeFirst() { +inline void MXDataIter::BeforeFirst() { int r = MXDataIterBeforeFirst(blob_ptr_->handle_); CHECK_EQ(r, 0); } -bool MXDataIter::Next() { +inline bool MXDataIter::Next() { int out; int r = MXDataIterNext(blob_ptr_->handle_, &out); CHECK_EQ(r, 0); return out; } -NDArray MXDataIter::GetData() { +inline NDArray MXDataIter::GetData() { NDArrayHandle handle; int r = MXDataIterGetData(blob_ptr_->handle_, &handle); CHECK_EQ(r, 0); return NDArray(handle); } -NDArray MXDataIter::GetLabel() { +inline NDArray MXDataIter::GetLabel() { NDArrayHandle handle; int r = MXDataIterGetLabel(blob_ptr_->handle_, &handle); CHECK_EQ(r, 0); return NDArray(handle); } -int MXDataIter::GetPadNum() { +inline int MXDataIter::GetPadNum() { int out; int r = MXDataIterGetPadNum(blob_ptr_->handle_, &out); CHECK_EQ(r, 0); return out; } -std::vector MXDataIter::GetIndex() { +inline std::vector MXDataIter::GetIndex() { uint64_t *out_index, out_size; int r = MXDataIterGetIndex(blob_ptr_->handle_, &out_index, &out_size); CHECK_EQ(r, 0); @@ -64,7 +67,7 @@ std::vector MXDataIter::GetIndex() { return ret; } -MXDataIter MXDataIter::CreateDataIter() { +inline MXDataIter MXDataIter::CreateDataIter() { std::vector param_keys; std::vector param_values; diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h index a7f8404bed8e..ef2b7de02fef 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.h +++ b/cpp-package/include/mxnet-cpp/kvstore.h @@ -41,6 +41,9 @@ class KVStore { private: KVStoreHandle handle_; std::unique_ptr optimizer_; + static KVStore*& kvstore_ptr(); + static void Controller(int head, const char* body, void* controller_handle); + static void Updater(int key, NDArrayHandle recv, NDArrayHandle local, void* handle_); }; } // namespace cpp diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index f4dd765d2f8b..93bff1571dc4 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -20,54 +20,54 @@ namespace mxnet { namespace cpp { -namespace private_ { - KVStore *kvstore = nullptr; +inline KVStore*& KVStore::kvstore_ptr() { + static KVStore* kvstore_ = nullptr; + return kvstore_; +} - extern "C" - void controller(int head, const char* body, void * controller_handle) { - if (kvstore == nullptr) { - return; +inline void KVStore::Controller(int head, const char* body, void* controller_handle) { + if (kvstore_ptr() == nullptr) { + return; + } + if (head == 0) { + std::map params; + std::istringstream sin(body); + std::string line; + while (getline(sin, line)) { + size_t n = line.find('='); + params.emplace(line.substr(0, n), line.substr(n+1)); } - if (head == 0) { - std::map params; - std::istringstream sin(body); - std::string line; - while (getline(sin, line)) { - size_t n = line.find('='); - params.emplace(line.substr(0, n), line.substr(n+1)); - } - std::unique_ptr opt(OptimizerRegistry::Find(params.at("opt_type"))); - params.erase("opt_type"); - for (const auto& pair : params) { - opt->SetParam(pair.first, pair.second); - } - kvstore->SetOptimizer(std::move(opt), true); + std::unique_ptr opt(OptimizerRegistry::Find(params.at("opt_type"))); + params.erase("opt_type"); + for (const auto& pair : params) { + opt->SetParam(pair.first, pair.second); } + kvstore_ptr()->SetOptimizer(std::move(opt), true); } -} // namespace private_ +} -KVStore::KVStore(const std::string& name) { +inline KVStore::KVStore(const std::string& name) { CHECK_EQ(MXKVStoreCreate(name.c_str(), &handle_), 0); } -KVStore::KVStore(KVStore &&kv) { +inline KVStore::KVStore(KVStore &&kv) { optimizer_ = std::move(kv.optimizer_); handle_ = kv.handle_; kv.handle_ = nullptr; } -void KVStore::RunServer() { +inline void KVStore::RunServer() { CHECK_NE(GetRole(), "worker"); - private_::kvstore = this; - CHECK_EQ(MXKVStoreRunServer(handle_, &private_::controller, 0), 0); + kvstore_ptr() = this; + CHECK_EQ(MXKVStoreRunServer(handle_, &Controller, 0), 0); } -void KVStore::Init(int key, const NDArray& val) { +inline void KVStore::Init(int key, const NDArray& val) { NDArrayHandle val_handle = val.GetHandle(); CHECK_EQ(MXKVStoreInit(handle_, 1, &key, &val_handle), 0); } -void KVStore::Init(const std::vector& keys, const std::vector& vals) { +inline void KVStore::Init(const std::vector& keys, const std::vector& vals) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), @@ -79,14 +79,14 @@ void KVStore::Init(const std::vector& keys, const std::vector& val val_handles.data()), 0); } -void KVStore::Push(int key, const NDArray& val, int priority) { +inline void KVStore::Push(int key, const NDArray& val, int priority) { NDArrayHandle val_handle = val.GetHandle(); CHECK_EQ(MXKVStorePush(handle_, 1, &key, &val_handle, priority), 0); } -void KVStore::Push(const std::vector& keys, - const std::vector& vals, - int priority) { +inline void KVStore::Push(const std::vector& keys, + const std::vector& vals, + int priority) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), @@ -98,12 +98,12 @@ void KVStore::Push(const std::vector& keys, val_handles.data(), priority), 0); } -void KVStore::Pull(int key, NDArray* out, int priority) { +inline void KVStore::Pull(int key, NDArray* out, int priority) { NDArrayHandle out_handle = out->GetHandle(); CHECK_EQ(MXKVStorePull(handle_, 1, &key, &out_handle, priority), 0); } -void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { +inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { CHECK_EQ(keys.size(), outs->size()); std::vector out_handles(keys.size()); @@ -116,48 +116,45 @@ void KVStore::Pull(const std::vector& keys, std::vector* outs, int out_handles.data(), priority), 0); } -namespace private_ { - extern "C" - void updater(int key, NDArrayHandle recv, NDArrayHandle local, - void* handle_) { - Optimizer *opt = static_cast(handle_); - opt->Update(key, NDArray(local), NDArray(recv)); - } +inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local, + void* handle_) { + Optimizer *opt = static_cast(handle_); + opt->Update(key, NDArray(local), NDArray(recv)); } -void KVStore::SetOptimizer(std::unique_ptr optimizer, bool local) { +inline void KVStore::SetOptimizer(std::unique_ptr optimizer, bool local) { if (local) { optimizer_ = std::move(optimizer); - CHECK_EQ(MXKVStoreSetUpdater(handle_, &private_::updater, optimizer_.get()), 0); + CHECK_EQ(MXKVStoreSetUpdater(handle_, &Updater, optimizer_.get()), 0); } else { CHECK_EQ(MXKVStoreSendCommmandToServers(handle_, 0, (*optimizer).Serialize().c_str()), 0); } } -std::string KVStore::GetType() const { +inline std::string KVStore::GetType() const { const char *type; CHECK_EQ(MXKVStoreGetType(handle_, &type), 0); // type is managed by handle_, no need to free its memory. return type; } -int KVStore::GetRank() const { +inline int KVStore::GetRank() const { int rank; CHECK_EQ(MXKVStoreGetRank(handle_, &rank), 0); return rank; } -int KVStore::GetNumWorkers() const { +inline int KVStore::GetNumWorkers() const { int num_workers; CHECK_EQ(MXKVStoreGetGroupSize(handle_, &num_workers), 0); return num_workers; } -void KVStore::Barrier() const { +inline void KVStore::Barrier() const { CHECK_EQ(MXKVStoreBarrier(handle_), 0); } -std::string KVStore::GetRole() const { +inline std::string KVStore::GetRole() const { int ret; CHECK_EQ(MXKVStoreIsSchedulerNode(&ret), 0); if (ret) { diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index f7b5c3233205..ef52df762aab 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -17,37 +17,37 @@ namespace mxnet { namespace cpp { -NDArray::NDArray() { +inline NDArray::NDArray() { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreateNone(&handle), 0); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const NDArrayHandle &handle) { +inline NDArray::NDArray(const NDArrayHandle &handle) { blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const std::vector &shape, const Context &context, - bool delay_alloc) { +inline NDArray::NDArray(const std::vector &shape, const Context &context, + bool delay_alloc) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreate(shape.data(), shape.size(), context.GetDeviceType(), context.GetDeviceId(), delay_alloc, &handle), 0); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const Shape &shape, const Context &context, bool delay_alloc) { +inline NDArray::NDArray(const Shape &shape, const Context &context, bool delay_alloc) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(), context.GetDeviceId(), delay_alloc, &handle), 0); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const mx_float *data, size_t size) { +inline NDArray::NDArray(const mx_float *data, size_t size) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreateNone(&handle), 0); MXNDArraySyncCopyFromCPU(handle, data, size); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const mx_float *data, const Shape &shape, - const Context &context) { +inline NDArray::NDArray(const mx_float *data, const Shape &shape, + const Context &context) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(), context.GetDeviceId(), false, &handle), @@ -55,8 +55,8 @@ NDArray::NDArray(const mx_float *data, const Shape &shape, MXNDArraySyncCopyFromCPU(handle, data, shape.Size()); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const std::vector &data, const Shape &shape, - const Context &context) { +inline NDArray::NDArray(const std::vector &data, const Shape &shape, + const Context &context) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(), context.GetDeviceId(), false, &handle), @@ -64,125 +64,125 @@ NDArray::NDArray(const std::vector &data, const Shape &shape, MXNDArraySyncCopyFromCPU(handle, data.data(), shape.Size()); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const std::vector &data) { +inline NDArray::NDArray(const std::vector &data) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreateNone(&handle), 0); MXNDArraySyncCopyFromCPU(handle, data.data(), data.size()); blob_ptr_ = std::make_shared(handle); } -NDArray NDArray::operator+(mx_float scalar) { +inline NDArray NDArray::operator+(mx_float scalar) { NDArray ret; Operator("_plus_scalar")(*this, scalar).Invoke(ret); return ret; } -NDArray NDArray::operator-(mx_float scalar) { +inline NDArray NDArray::operator-(mx_float scalar) { NDArray ret; Operator("_minus_scalar")(*this, scalar).Invoke(ret); return ret; } -NDArray NDArray::operator*(mx_float scalar) { +inline NDArray NDArray::operator*(mx_float scalar) { NDArray ret; Operator("_mul_scalar")(*this, scalar).Invoke(ret); return ret; } -NDArray NDArray::operator/(mx_float scalar) { +inline NDArray NDArray::operator/(mx_float scalar) { NDArray ret; Operator("_div_scalar")(*this, scalar).Invoke(ret); return ret; } -NDArray NDArray::operator+(const NDArray &rhs) { +inline NDArray NDArray::operator+(const NDArray &rhs) { NDArray ret; Operator("_plus")(*this, rhs).Invoke(ret); return ret; } -NDArray NDArray::operator-(const NDArray &rhs) { +inline NDArray NDArray::operator-(const NDArray &rhs) { NDArray ret; Operator("_minus")(*this, rhs).Invoke(ret); return ret; } -NDArray NDArray::operator*(const NDArray &rhs) { +inline NDArray NDArray::operator*(const NDArray &rhs) { NDArray ret; Operator("_mul")(*this, rhs).Invoke(ret); return ret; } -NDArray NDArray::operator/(const NDArray &rhs) { +inline NDArray NDArray::operator/(const NDArray &rhs) { NDArray ret; Operator("_div")(*this, rhs).Invoke(ret); return ret; } -NDArray &NDArray::operator=(mx_float scalar) { +inline NDArray &NDArray::operator=(mx_float scalar) { Operator("_set_value")(scalar).Invoke(*this); return *this; } -NDArray &NDArray::operator+=(mx_float scalar) { +inline NDArray &NDArray::operator+=(mx_float scalar) { Operator("_plus_scalar")(*this, scalar).Invoke(*this); return *this; } -NDArray &NDArray::operator-=(mx_float scalar) { +inline NDArray &NDArray::operator-=(mx_float scalar) { Operator("_minus_scalar")(*this, scalar).Invoke(*this); return *this; } -NDArray &NDArray::operator*=(mx_float scalar) { +inline NDArray &NDArray::operator*=(mx_float scalar) { Operator("_mul_scalar")(*this, scalar).Invoke(*this); return *this; } -NDArray &NDArray::operator/=(mx_float scalar) { +inline NDArray &NDArray::operator/=(mx_float scalar) { Operator("_div_scalar")(*this, scalar).Invoke(*this); return *this; } -NDArray &NDArray::operator+=(const NDArray &rhs) { +inline NDArray &NDArray::operator+=(const NDArray &rhs) { Operator("_plus")(*this, rhs).Invoke(*this); return *this; } -NDArray &NDArray::operator-=(const NDArray &rhs) { +inline NDArray &NDArray::operator-=(const NDArray &rhs) { Operator("_minus")(*this, rhs).Invoke(*this); return *this; } -NDArray &NDArray::operator*=(const NDArray &rhs) { +inline NDArray &NDArray::operator*=(const NDArray &rhs) { Operator("_mul")(*this, rhs).Invoke(*this); return *this; } -NDArray &NDArray::operator/=(const NDArray &rhs) { +inline NDArray &NDArray::operator/=(const NDArray &rhs) { Operator("_div")(*this, rhs).Invoke(*this); return *this; } -NDArray NDArray::ArgmaxChannel() { +inline NDArray NDArray::ArgmaxChannel() { NDArray ret; Operator("argmax_channel")(*this).Invoke(ret); return ret; } -void NDArray::SyncCopyFromCPU(const mx_float *data, size_t size) { +inline void NDArray::SyncCopyFromCPU(const mx_float *data, size_t size) { MXNDArraySyncCopyFromCPU(blob_ptr_->handle_, data, size); } -void NDArray::SyncCopyFromCPU(const std::vector &data) { +inline void NDArray::SyncCopyFromCPU(const std::vector &data) { MXNDArraySyncCopyFromCPU(blob_ptr_->handle_, data.data(), data.size()); } -void NDArray::SyncCopyToCPU(mx_float *data, size_t size) { +inline void NDArray::SyncCopyToCPU(mx_float *data, size_t size) { MXNDArraySyncCopyToCPU(blob_ptr_->handle_, data, size > 0 ? size : Size()); } -void NDArray::SyncCopyToCPU(std::vector *data, size_t size) { +inline void NDArray::SyncCopyToCPU(std::vector *data, size_t size) { size = size > 0 ? size : Size(); data->resize(size); MXNDArraySyncCopyToCPU(blob_ptr_->handle_, data->data(), size); } -NDArray NDArray::Copy(const Context &ctx) const { +inline NDArray NDArray::Copy(const Context &ctx) const { NDArray ret(GetShape(), ctx); Operator("_copyto")(*this).Invoke(ret); return ret; } -NDArray NDArray::CopyTo(NDArray * other) const { +inline NDArray NDArray::CopyTo(NDArray * other) const { Operator("_copyto")(*this).Invoke(*other); return *other; } -NDArray NDArray::Slice(mx_uint begin, mx_uint end) const { +inline NDArray NDArray::Slice(mx_uint begin, mx_uint end) const { NDArrayHandle handle; CHECK_EQ(MXNDArraySlice(GetHandle(), begin, end, &handle), 0); return NDArray(handle); } -NDArray NDArray::Reshape(const Shape &new_shape) const { +inline NDArray NDArray::Reshape(const Shape &new_shape) const { NDArrayHandle handle; std::vector dims(new_shape.ndim()); for (index_t i = 0; i < new_shape.ndim(); ++i) { @@ -193,22 +193,22 @@ NDArray NDArray::Reshape(const Shape &new_shape) const { MXNDArrayReshape(GetHandle(), new_shape.ndim(), dims.data(), &handle), 0); return NDArray(handle); } -void NDArray::WaitToRead() const { +inline void NDArray::WaitToRead() const { CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0); } -void NDArray::WaitToWrite() { +inline void NDArray::WaitToWrite() { CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0); } -void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); } -void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) { +inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); } +inline void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) { Operator("_sample_normal")(mu, sigma).Invoke(*out); } -void NDArray::SampleUniform(mx_float begin, mx_float end, NDArray *out) { +inline void NDArray::SampleUniform(mx_float begin, mx_float end, NDArray *out) { Operator("_sample_uniform")(begin, end).Invoke(*out); } -void NDArray::Load(const std::string &file_name, - std::vector *array_list, - std::map *array_map) { +inline void NDArray::Load(const std::string &file_name, + std::vector *array_list, + std::map *array_map) { mx_uint out_size, out_name_size; NDArrayHandle *out_arr; const char **out_names; @@ -227,7 +227,7 @@ void NDArray::Load(const std::string &file_name, } } } -std::map NDArray::LoadToMap( +inline std::map NDArray::LoadToMap( const std::string &file_name) { std::map array_map; mx_uint out_size, out_name_size; @@ -244,7 +244,7 @@ std::map NDArray::LoadToMap( } return array_map; } -std::vector NDArray::LoadToList(const std::string &file_name) { +inline std::vector NDArray::LoadToList(const std::string &file_name) { std::vector array_list; mx_uint out_size, out_name_size; NDArrayHandle *out_arr; @@ -257,8 +257,8 @@ std::vector NDArray::LoadToList(const std::string &file_name) { } return array_list; } -void NDArray::Save(const std::string &file_name, - const std::map &array_map) { +inline void NDArray::Save(const std::string &file_name, + const std::map &array_map) { std::vector args; std::vector keys; for (const auto &t : array_map) { @@ -269,8 +269,8 @@ void NDArray::Save(const std::string &file_name, MXNDArraySave(file_name.c_str(), args.size(), args.data(), keys.data()), 0); } -void NDArray::Save(const std::string &file_name, - const std::vector &array_list) { +inline void NDArray::Save(const std::string &file_name, + const std::vector &array_list) { std::vector args; for (const auto &t : array_list) { args.push_back(t.GetHandle()); @@ -279,30 +279,30 @@ void NDArray::Save(const std::string &file_name, 0); } -size_t NDArray::Offset(size_t h, size_t w) const { +inline size_t NDArray::Offset(size_t h, size_t w) const { return (h * GetShape()[1]) + w; } -size_t NDArray::Offset(size_t c, size_t h, size_t w) const { +inline size_t NDArray::Offset(size_t c, size_t h, size_t w) const { auto const shape = GetShape(); return h * shape[0] * shape[2] + w * shape[0] + c; } -mx_float NDArray::At(size_t h, size_t w) const { +inline mx_float NDArray::At(size_t h, size_t w) const { return GetData()[Offset(h, w)]; } -mx_float NDArray::At(size_t c, size_t h, size_t w) const { +inline mx_float NDArray::At(size_t c, size_t h, size_t w) const { return GetData()[Offset(c, h, w)]; } -size_t NDArray::Size() const { +inline size_t NDArray::Size() const { size_t ret = 1; for (auto &i : GetShape()) ret *= i; return ret; } -std::vector NDArray::GetShape() const { +inline std::vector NDArray::GetShape() const { const mx_uint *out_pdata; mx_uint out_dim; MXNDArrayGetShape(blob_ptr_->handle_, &out_dim, &out_pdata); @@ -313,13 +313,13 @@ std::vector NDArray::GetShape() const { return ret; } -const mx_float *NDArray::GetData() const { +inline const mx_float *NDArray::GetData() const { mx_float *ret; CHECK_NE(GetContext().GetDeviceType(), DeviceType::kGPU); MXNDArrayGetData(blob_ptr_->handle_, &ret); return ret; } -Context NDArray::GetContext() const { +inline Context NDArray::GetContext() const { int out_dev_type; int out_dev_id; MXNDArrayGetContext(blob_ptr_->handle_, &out_dev_type, &out_dev_id); diff --git a/cpp-package/include/mxnet-cpp/operator.h b/cpp-package/include/mxnet-cpp/operator.h index 66aec7fa0eda..153679c4d695 100644 --- a/cpp-package/include/mxnet-cpp/operator.h +++ b/cpp-package/include/mxnet-cpp/operator.h @@ -180,7 +180,7 @@ class Operator { std::vector input_keys; std::vector arg_names_; AtomicSymbolCreator handle_; - static OpMap *op_map_; + static OpMap*& op_map(); }; } // namespace cpp } // namespace mxnet diff --git a/cpp-package/include/mxnet-cpp/operator.hpp b/cpp-package/include/mxnet-cpp/operator.hpp index 3c8d1afe9b5f..84d8564137ff 100644 --- a/cpp-package/include/mxnet-cpp/operator.hpp +++ b/cpp-package/include/mxnet-cpp/operator.hpp @@ -24,20 +24,23 @@ namespace cpp { * like PushInput, which is not allowed in C++ */ template <> -Operator& Operator::SetParam(int pos, const NDArray &value) { +inline Operator& Operator::SetParam(int pos, const NDArray &value) { input_ndarrays.push_back(value.GetHandle()); return *this; } template <> -Operator& Operator::SetParam(int pos, const Symbol &value) { +inline Operator& Operator::SetParam(int pos, const Symbol &value) { input_symbols.push_back(value.GetHandle()); return *this; } -OpMap *Operator::op_map_ = new OpMap(); +inline OpMap*& Operator::op_map() { + static OpMap *op_map_ = new OpMap(); + return op_map_; +} -Operator::Operator(const std::string &operator_name) { - handle_ = op_map_->GetSymbolCreator(operator_name); +inline Operator::Operator(const std::string &operator_name) { + handle_ = op_map()->GetSymbolCreator(operator_name); const char *name; const char *description; mx_uint num_args; @@ -58,7 +61,7 @@ Operator::Operator(const std::string &operator_name) { } } -Symbol Operator::CreateSymbol(const std::string &name) { +inline Symbol Operator::CreateSymbol(const std::string &name) { if (input_keys.size() > 0) { CHECK_EQ(input_keys.size(), input_symbols.size()); } @@ -86,7 +89,7 @@ Symbol Operator::CreateSymbol(const std::string &name) { return Symbol(symbol_handle); } -void Operator::Invoke(std::vector &outputs) { +inline void Operator::Invoke(std::vector &outputs) { if (input_keys.size() > 0) { CHECK_EQ(input_keys.size(), input_ndarrays.size()); } @@ -126,24 +129,24 @@ void Operator::Invoke(std::vector &outputs) { }); } -std::vector Operator::Invoke() { +inline std::vector Operator::Invoke() { std::vector outputs; Invoke(outputs); return outputs; } -void Operator::Invoke(NDArray &output) { +inline void Operator::Invoke(NDArray &output) { std::vector outputs{output}; Invoke(outputs); } -Operator &Operator::SetInput(const std::string &name, Symbol symbol) { +inline Operator &Operator::SetInput(const std::string &name, Symbol symbol) { input_keys.push_back(name.c_str()); input_symbols.push_back(symbol.GetHandle()); return *this; } -Operator &Operator::SetInput(const std::string &name, NDArray ndarray) { +inline Operator &Operator::SetInput(const std::string &name, NDArray ndarray) { input_keys.push_back(name.c_str()); input_ndarrays.push_back(ndarray.GetHandle()); return *this; diff --git a/cpp-package/include/mxnet-cpp/optimizer.h b/cpp-package/include/mxnet-cpp/optimizer.h index 03c0d90c7b97..94d366f02da4 100644 --- a/cpp-package/include/mxnet-cpp/optimizer.h +++ b/cpp-package/include/mxnet-cpp/optimizer.h @@ -81,7 +81,7 @@ class Optimizer { protected: std::map params_; - static OpMap *op_map_; + static OpMap*& op_map(); const std::vector GetParamKeys_() const; const std::vector GetParamValues_() const; }; @@ -93,7 +93,7 @@ class OptimizerRegistry { static Optimizer* Find(const std::string& name); static int __REGISTER__(const std::string& name, OptimizerCreator creator); private: - static std::map cmap_; + static std::map& cmap(); OptimizerRegistry() = delete; ~OptimizerRegistry() = delete; }; diff --git a/cpp-package/include/mxnet-cpp/optimizer.hpp b/cpp-package/include/mxnet-cpp/optimizer.hpp index 94af0ec759a2..2a6c6a67a708 100644 --- a/cpp-package/include/mxnet-cpp/optimizer.hpp +++ b/cpp-package/include/mxnet-cpp/optimizer.hpp @@ -21,23 +21,26 @@ namespace mxnet { namespace cpp { -OpMap* Optimizer::op_map_ = new OpMap(); - -std::map OptimizerRegistry::cmap_; +inline std::map& OptimizerRegistry::cmap() { + static std::map cmap_; + return cmap_; +} -MXNETCPP_REGISTER_OPTIMIZER(sgd, SGDOptimizer); -MXNETCPP_REGISTER_OPTIMIZER(ccsgd, SGDOptimizer); // For backward compatibility +inline OpMap*& Optimizer::op_map() { + static OpMap *op_map_ = new OpMap(); + return op_map_; +} -Optimizer::~Optimizer() {} +inline Optimizer::~Optimizer() {} -void Optimizer::Update(int index, NDArray weight, NDArray grad, mx_float lr, +inline void Optimizer::Update(int index, NDArray weight, NDArray grad, mx_float lr, mx_float wd) { params_["lr"] = std::to_string(lr); params_["wd"] = std::to_string(wd); Update(index, weight, grad); } -std::string Optimizer::Serialize() const { +inline std::string Optimizer::Serialize() const { using ValueType = std::map::value_type; auto params = params_; params.emplace("opt_type", GetType()); @@ -47,49 +50,51 @@ std::string Optimizer::Serialize() const { }).substr(1); } -const std::vector Optimizer::GetParamKeys_() const { +inline const std::vector Optimizer::GetParamKeys_() const { std::vector keys; for (auto& iter : params_) keys.push_back(iter.first.c_str()); return keys; } -const std::vector Optimizer::GetParamValues_() const { +inline const std::vector Optimizer::GetParamValues_() const { std::vector values; for (auto& iter : params_) values.push_back(iter.second.c_str()); return values; } -Optimizer* OptimizerRegistry::Find(const std::string& name) { - auto it = cmap_.find(name); - if (it == cmap_.end()) +inline Optimizer* OptimizerRegistry::Find(const std::string& name) { + MXNETCPP_REGISTER_OPTIMIZER(sgd, SGDOptimizer); + MXNETCPP_REGISTER_OPTIMIZER(ccsgd, SGDOptimizer); // For backward compatibility + auto it = cmap().find(name); + if (it == cmap().end()) return nullptr; return it->second(); } -int OptimizerRegistry::__REGISTER__(const std::string& name, OptimizerCreator creator) { - CHECK_EQ(cmap_.count(name), 0) << name << " already registered"; - cmap_.emplace(name, std::move(creator)); +inline int OptimizerRegistry::__REGISTER__(const std::string& name, OptimizerCreator creator) { + CHECK_EQ(cmap().count(name), 0) << name << " already registered"; + cmap().emplace(name, std::move(creator)); return 0; } -std::string SGDOptimizer::GetType() const { +inline std::string SGDOptimizer::GetType() const { return "sgd"; } -SGDOptimizer::SGDOptimizer() { - update_handle_ = op_map_->GetSymbolCreator("sgd_update"); - mom_update_handle_ = op_map_->GetSymbolCreator("sgd_mom_update"); +inline SGDOptimizer::SGDOptimizer() { + update_handle_ = op_map()->GetSymbolCreator("sgd_update"); + mom_update_handle_ = op_map()->GetSymbolCreator("sgd_mom_update"); } -SGDOptimizer::~SGDOptimizer() { +inline SGDOptimizer::~SGDOptimizer() { for (auto &it : states_) { delete it.second; } } -void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) { +inline void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) { if (states_.count(index) == 0) { CreateState_(index, weight); } @@ -118,7 +123,7 @@ void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) { } } -void SGDOptimizer::CreateState_(int index, NDArray weight) { +inline void SGDOptimizer::CreateState_(int index, NDArray weight) { if (params_.count("momentum") == 0) { states_[index] = nullptr; } else { diff --git a/cpp-package/include/mxnet-cpp/symbol.h b/cpp-package/include/mxnet-cpp/symbol.h index 63ef9b1a03e3..0911c03e3937 100644 --- a/cpp-package/include/mxnet-cpp/symbol.h +++ b/cpp-package/include/mxnet-cpp/symbol.h @@ -246,7 +246,7 @@ class Symbol { private: std::shared_ptr blob_ptr_; - static OpMap *op_map_; + static OpMap*& op_map(); }; Symbol operator+(mx_float lhs, const Symbol &rhs); Symbol operator-(mx_float lhs, const Symbol &rhs); diff --git a/cpp-package/include/mxnet-cpp/symbol.hpp b/cpp-package/include/mxnet-cpp/symbol.hpp index f79e96a59fc4..b8e020b53584 100644 --- a/cpp-package/include/mxnet-cpp/symbol.hpp +++ b/cpp-package/include/mxnet-cpp/symbol.hpp @@ -20,39 +20,42 @@ namespace mxnet { namespace cpp { -OpMap *Symbol::op_map_ = new OpMap(); -Symbol::Symbol(SymbolHandle handle) { +inline OpMap*& Symbol::op_map() { + static OpMap* op_map_ = new OpMap(); + return op_map_; +} +inline Symbol::Symbol(SymbolHandle handle) { blob_ptr_ = std::make_shared(handle); } -Symbol::Symbol(const char *name) { +inline Symbol::Symbol(const char *name) { SymbolHandle handle; CHECK_EQ(MXSymbolCreateVariable(name, &(handle)), 0); blob_ptr_ = std::make_shared(handle); } -Symbol::Symbol(const std::string &name) : Symbol(name.c_str()) {} -Symbol Symbol::Variable(const std::string &name) { return Symbol(name); } -Symbol Symbol::operator+(const Symbol &rhs) const { return _Plus(*this, rhs); } -Symbol Symbol::operator-(const Symbol &rhs) const { return _Minus(*this, rhs); } -Symbol Symbol::operator*(const Symbol &rhs) const { return _Mul(*this, rhs); } -Symbol Symbol::operator/(const Symbol &rhs) const { return _Div(*this, rhs); } -Symbol Symbol::operator+(mx_float scalar) const { +inline Symbol::Symbol(const std::string &name) : Symbol(name.c_str()) {} +inline Symbol Symbol::Variable(const std::string &name) { return Symbol(name); } +inline Symbol Symbol::operator+(const Symbol &rhs) const { return _Plus(*this, rhs); } +inline Symbol Symbol::operator-(const Symbol &rhs) const { return _Minus(*this, rhs); } +inline Symbol Symbol::operator*(const Symbol &rhs) const { return _Mul(*this, rhs); } +inline Symbol Symbol::operator/(const Symbol &rhs) const { return _Div(*this, rhs); } +inline Symbol Symbol::operator+(mx_float scalar) const { return _PlusScalar(*this, scalar); } -Symbol Symbol::operator-(mx_float scalar) const { +inline Symbol Symbol::operator-(mx_float scalar) const { return _MinusScalar(*this, scalar); } -Symbol Symbol::operator*(mx_float scalar) const { +inline Symbol Symbol::operator*(mx_float scalar) const { return _MulScalar(*this, scalar); } -Symbol Symbol::operator/(mx_float scalar) const { +inline Symbol Symbol::operator/(mx_float scalar) const { return _DivScalar(*this, scalar); } -Symbol Symbol::operator[](int index) { +inline Symbol Symbol::operator[](int index) { SymbolHandle out; MXSymbolGetOutput(GetHandle(), index, &out); return Symbol(out); } -Symbol Symbol::operator[](const std::string &index) { +inline Symbol Symbol::operator[](const std::string &index) { auto outputs = ListOutputs(); for (mx_uint i = 0; i < outputs.size(); ++i) { if (outputs[i] == index) { @@ -62,7 +65,7 @@ Symbol Symbol::operator[](const std::string &index) { LOG(FATAL) << "Cannot find output that matches name " << index; return (*this)[0]; } -Symbol Symbol::Group(const std::vector &symbols) { +inline Symbol Symbol::Group(const std::vector &symbols) { SymbolHandle out; std::vector handle_list; for (const auto &t : symbols) { @@ -71,36 +74,36 @@ Symbol Symbol::Group(const std::vector &symbols) { MXSymbolCreateGroup(handle_list.size(), handle_list.data(), &out); return Symbol(out); } -Symbol Symbol::Load(const std::string &file_name) { +inline Symbol Symbol::Load(const std::string &file_name) { SymbolHandle handle; CHECK_EQ(MXSymbolCreateFromFile(file_name.c_str(), &(handle)), 0); return Symbol(handle); } -Symbol Symbol::LoadJSON(const std::string &json_str) { +inline Symbol Symbol::LoadJSON(const std::string &json_str) { SymbolHandle handle; CHECK_EQ(MXSymbolCreateFromJSON(json_str.c_str(), &(handle)), 0); return Symbol(handle); } -void Symbol::Save(const std::string &file_name) const { +inline void Symbol::Save(const std::string &file_name) const { CHECK_EQ(MXSymbolSaveToFile(GetHandle(), file_name.c_str()), 0); } -std::string Symbol::ToJSON() const { +inline std::string Symbol::ToJSON() const { const char *out_json; CHECK_EQ(MXSymbolSaveToJSON(GetHandle(), &out_json), 0); return std::string(out_json); } -Symbol Symbol::GetInternals() const { +inline Symbol Symbol::GetInternals() const { SymbolHandle handle; CHECK_EQ(MXSymbolGetInternals(GetHandle(), &handle), 0); return Symbol(handle); } -Symbol::Symbol(const std::string &operator_name, const std::string &name, +inline Symbol::Symbol(const std::string &operator_name, const std::string &name, std::vector input_keys, std::vector input_values, std::vector config_keys, std::vector config_values) { SymbolHandle handle; - AtomicSymbolCreator creator = op_map_->GetSymbolCreator(operator_name); + AtomicSymbolCreator creator = op_map()->GetSymbolCreator(operator_name); MXSymbolCreateAtomicSymbol(creator, config_keys.size(), config_keys.data(), config_values.data(), &handle); MXSymbolCompose(handle, operator_name.c_str(), input_keys.size(), @@ -108,13 +111,13 @@ Symbol::Symbol(const std::string &operator_name, const std::string &name, blob_ptr_ = std::make_shared(handle); } -Symbol Symbol::Copy() const { +inline Symbol Symbol::Copy() const { SymbolHandle handle; CHECK_EQ(MXSymbolCopy(GetHandle(), &handle), 0); return Symbol(handle); } -std::vector Symbol::ListArguments() const { +inline std::vector Symbol::ListArguments() const { std::vector ret; mx_uint size; const char **sarr; @@ -124,7 +127,7 @@ std::vector Symbol::ListArguments() const { } return ret; } -std::vector Symbol::ListOutputs() const { +inline std::vector Symbol::ListOutputs() const { std::vector ret; mx_uint size; const char **sarr; @@ -134,7 +137,7 @@ std::vector Symbol::ListOutputs() const { } return ret; } -std::vector Symbol::ListAuxiliaryStates() const { +inline std::vector Symbol::ListAuxiliaryStates() const { std::vector ret; mx_uint size; const char **sarr; @@ -145,7 +148,7 @@ std::vector Symbol::ListAuxiliaryStates() const { return ret; } -void Symbol::InferShape( +inline void Symbol::InferShape( const std::map > &arg_shapes, std::vector > *in_shape, std::vector > *aux_shape, @@ -205,7 +208,7 @@ void Symbol::InferShape( } } -void Symbol::InferExecutorArrays( +inline void Symbol::InferExecutorArrays( const Context &context, std::vector *arg_arrays, std::vector *grad_arrays, std::vector *grad_reqs, std::vector *aux_arrays, @@ -267,7 +270,7 @@ void Symbol::InferExecutorArrays( } } } -void Symbol::InferArgsMap( +inline void Symbol::InferArgsMap( const Context &context, std::map *args_map, const std::map &known_args) const { @@ -297,7 +300,7 @@ void Symbol::InferArgsMap( } } -Executor *Symbol::SimpleBind( +inline Executor *Symbol::SimpleBind( const Context &context, const std::map &args_map, const std::map &arg_grad_store, const std::map &grad_req_type, @@ -315,7 +318,7 @@ Executor *Symbol::SimpleBind( aux_arrays); } -Executor *Symbol::Bind(const Context &context, +inline Executor *Symbol::Bind(const Context &context, const std::vector &arg_arrays, const std::vector &grad_arrays, const std::vector &grad_reqs, @@ -325,12 +328,12 @@ Executor *Symbol::Bind(const Context &context, return new Executor(*this, context, arg_arrays, grad_arrays, grad_reqs, aux_arrays, group_to_ctx, shared_exec); } -Symbol operator+(mx_float lhs, const Symbol &rhs) { return rhs + lhs; } -Symbol operator-(mx_float lhs, const Symbol &rhs) { +inline Symbol operator+(mx_float lhs, const Symbol &rhs) { return rhs + lhs; } +inline Symbol operator-(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RMinusScalar(lhs, rhs); } -Symbol operator*(mx_float lhs, const Symbol &rhs) { return rhs * lhs; } -Symbol operator/(mx_float lhs, const Symbol &rhs) { +inline Symbol operator*(mx_float lhs, const Symbol &rhs) { return rhs * lhs; } +inline Symbol operator/(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RDivScalar(lhs, rhs); } } // namespace cpp From 336dd7ecf0e24dfe7be0f8e614367e416bec587f Mon Sep 17 00:00:00 2001 From: Ming Chuan Date: Fri, 24 Mar 2017 18:49:38 +0800 Subject: [PATCH 2/2] [cpp-package] Define KVStore as a singleton --- cpp-package/include/mxnet-cpp/kvstore.h | 39 ++++++------ cpp-package/include/mxnet-cpp/kvstore.hpp | 73 ++++++++++++----------- 2 files changed, 55 insertions(+), 57 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h index ef2b7de02fef..c241cc481c11 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.h +++ b/cpp-package/include/mxnet-cpp/kvstore.h @@ -17,31 +17,28 @@ namespace cpp { class KVStore { public: - explicit inline KVStore(const std::string& name = "local"); - KVStore(const KVStore &) = delete; - // VS 2013 doesn't support default move constructor. - KVStore(KVStore &&); - inline void RunServer(); - inline void Init(int key, const NDArray& val); - inline void Init(const std::vector& keys, const std::vector& vals); - inline void Push(int key, const NDArray& val, int priority = 0); - inline void Push(const std::vector& keys, + static void SetType(const std::string& type); + static void RunServer(); + static void Init(int key, const NDArray& val); + static void Init(const std::vector& keys, const std::vector& vals); + static void Push(int key, const NDArray& val, int priority = 0); + static void Push(const std::vector& keys, const std::vector& vals, int priority = 0); - inline void Pull(int key, NDArray* out, int priority = 0); - inline void Pull(const std::vector& keys, std::vector* outs, int priority = 0); + static void Pull(int key, NDArray* out, int priority = 0); + static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); // TODO(lx): put lr in optimizer or not? - inline void SetOptimizer(std::unique_ptr optimizer, bool local = false); - inline std::string GetType() const; - inline int GetRank() const; - inline int GetNumWorkers() const; - inline void Barrier() const; - inline std::string GetRole() const; - ~KVStore() { MXKVStoreFree(handle_); } + static void SetOptimizer(std::unique_ptr optimizer, bool local = false); + static std::string GetType(); + static int GetRank(); + static int GetNumWorkers(); + static void Barrier(); + static std::string GetRole(); private: - KVStoreHandle handle_; - std::unique_ptr optimizer_; - static KVStore*& kvstore_ptr(); + KVStore(); + static KVStoreHandle& get_handle(); + static std::unique_ptr& get_optimizer(); + static KVStore*& get_kvstore(); static void Controller(int head, const char* body, void* controller_handle); static void Updater(int key, NDArrayHandle recv, NDArrayHandle local, void* handle_); }; diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index 93bff1571dc4..c84d010de358 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -20,15 +20,7 @@ namespace mxnet { namespace cpp { -inline KVStore*& KVStore::kvstore_ptr() { - static KVStore* kvstore_ = nullptr; - return kvstore_; -} - inline void KVStore::Controller(int head, const char* body, void* controller_handle) { - if (kvstore_ptr() == nullptr) { - return; - } if (head == 0) { std::map params; std::istringstream sin(body); @@ -42,29 +34,39 @@ inline void KVStore::Controller(int head, const char* body, void* controller_han for (const auto& pair : params) { opt->SetParam(pair.first, pair.second); } - kvstore_ptr()->SetOptimizer(std::move(opt), true); + get_kvstore()->SetOptimizer(std::move(opt), true); } } -inline KVStore::KVStore(const std::string& name) { - CHECK_EQ(MXKVStoreCreate(name.c_str(), &handle_), 0); +inline KVStoreHandle& KVStore::get_handle() { + static KVStoreHandle handle_ = nullptr; + return handle_; +} + +inline std::unique_ptr& KVStore::get_optimizer() { + static std::unique_ptr optimizer_; + return optimizer_; } -inline KVStore::KVStore(KVStore &&kv) { - optimizer_ = std::move(kv.optimizer_); - handle_ = kv.handle_; - kv.handle_ = nullptr; +inline KVStore*& KVStore::get_kvstore() { + static KVStore* kvstore_ = new KVStore; + return kvstore_; +} + +inline KVStore::KVStore() {}; + +inline void KVStore::SetType(const std::string& type) { + CHECK_EQ(MXKVStoreCreate(type.c_str(), &(get_kvstore()->get_handle())), 0); } inline void KVStore::RunServer() { CHECK_NE(GetRole(), "worker"); - kvstore_ptr() = this; - CHECK_EQ(MXKVStoreRunServer(handle_, &Controller, 0), 0); + CHECK_EQ(MXKVStoreRunServer(get_kvstore()->get_handle(), &Controller, 0), 0); } inline void KVStore::Init(int key, const NDArray& val) { NDArrayHandle val_handle = val.GetHandle(); - CHECK_EQ(MXKVStoreInit(handle_, 1, &key, &val_handle), 0); + CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0); } inline void KVStore::Init(const std::vector& keys, const std::vector& vals) { @@ -75,13 +77,13 @@ inline void KVStore::Init(const std::vector& keys, const std::vectorget_handle(), keys.size(), keys.data(), val_handles.data()), 0); } inline void KVStore::Push(int key, const NDArray& val, int priority) { NDArrayHandle val_handle = val.GetHandle(); - CHECK_EQ(MXKVStorePush(handle_, 1, &key, &val_handle, priority), 0); + CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0); } inline void KVStore::Push(const std::vector& keys, @@ -94,13 +96,13 @@ inline void KVStore::Push(const std::vector& keys, return val.GetHandle(); }); - CHECK_EQ(MXKVStorePush(handle_, keys.size(), keys.data(), + CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), keys.size(), keys.data(), val_handles.data(), priority), 0); } inline void KVStore::Pull(int key, NDArray* out, int priority) { NDArrayHandle out_handle = out->GetHandle(); - CHECK_EQ(MXKVStorePull(handle_, 1, &key, &out_handle, priority), 0); + CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0); } inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { @@ -112,7 +114,7 @@ inline void KVStore::Pull(const std::vector& keys, std::vector* ou return val.GetHandle(); }); - CHECK_EQ(MXKVStorePull(handle_, keys.size(), keys.data(), + CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), keys.size(), keys.data(), out_handles.data(), priority), 0); } @@ -124,37 +126,36 @@ inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local, inline void KVStore::SetOptimizer(std::unique_ptr optimizer, bool local) { if (local) { - optimizer_ = std::move(optimizer); - CHECK_EQ(MXKVStoreSetUpdater(handle_, &Updater, optimizer_.get()), 0); + get_kvstore()->get_optimizer() = std::move(optimizer); + CHECK_EQ(MXKVStoreSetUpdater(get_kvstore()->get_handle(), &Updater, get_kvstore()->get_optimizer().get()), 0); } else { - CHECK_EQ(MXKVStoreSendCommmandToServers(handle_, 0, (*optimizer).Serialize().c_str()), 0); + CHECK_EQ(MXKVStoreSendCommmandToServers(get_kvstore()->get_handle(), 0, (*optimizer).Serialize().c_str()), 0); } } -inline std::string KVStore::GetType() const { +inline std::string KVStore::GetType() { const char *type; - CHECK_EQ(MXKVStoreGetType(handle_, &type), 0); - // type is managed by handle_, no need to free its memory. + CHECK_EQ(MXKVStoreGetType(get_kvstore()->get_handle(), &type), 0); return type; } -inline int KVStore::GetRank() const { +inline int KVStore::GetRank() { int rank; - CHECK_EQ(MXKVStoreGetRank(handle_, &rank), 0); + CHECK_EQ(MXKVStoreGetRank(get_kvstore()->get_handle(), &rank), 0); return rank; } -inline int KVStore::GetNumWorkers() const { +inline int KVStore::GetNumWorkers() { int num_workers; - CHECK_EQ(MXKVStoreGetGroupSize(handle_, &num_workers), 0); + CHECK_EQ(MXKVStoreGetGroupSize(get_kvstore()->get_handle(), &num_workers), 0); return num_workers; } -inline void KVStore::Barrier() const { - CHECK_EQ(MXKVStoreBarrier(handle_), 0); +inline void KVStore::Barrier() { + CHECK_EQ(MXKVStoreBarrier(get_kvstore()->get_handle()), 0); } -inline std::string KVStore::GetRole() const { +inline std::string KVStore::GetRole() { int ret; CHECK_EQ(MXKVStoreIsSchedulerNode(&ret), 0); if (ret) {