diff --git a/cpp-package/include/mxnet-cpp/executor.hpp b/cpp-package/include/mxnet-cpp/executor.hpp index c642a96268dd..ebae734f88b8 100644 --- a/cpp-package/include/mxnet-cpp/executor.hpp +++ b/cpp-package/include/mxnet-cpp/executor.hpp @@ -16,13 +16,13 @@ namespace mxnet { namespace cpp { -Executor::Executor(const Symbol &symbol, Context context, - const std::vector &arg_arrays, - const std::vector &grad_arrays, - const std::vector &grad_reqs, - const std::vector &aux_arrays, - const std::map &group_to_ctx, - Executor *shared_exec) { +inline Executor::Executor(const Symbol &symbol, Context context, + const std::vector &arg_arrays, + const std::vector &grad_arrays, + const std::vector &grad_reqs, + const std::vector &aux_arrays, + const std::map &group_to_ctx, + Executor *shared_exec) { this->arg_arrays = arg_arrays; this->grad_arrays = grad_arrays; this->aux_arrays = aux_arrays; @@ -73,14 +73,14 @@ Executor::Executor(const Symbol &symbol, Context context, } } -std::string Executor::DebugStr() { +inline std::string Executor::DebugStr() { const char *output; MXExecutorPrint(handle_, &output); return std::string(output); } -void Executor::UpdateAll(Optimizer *opt, float lr, float wd, - int arg_update_begin, int arg_update_end) { +inline void Executor::UpdateAll(Optimizer *opt, float lr, float wd, + int arg_update_begin, int arg_update_end) { arg_update_end = arg_update_end < 0 ? arg_arrays.size() - 1 : arg_update_end; for (int i = arg_update_begin; i < arg_update_end; ++i) { opt->Update(i, arg_arrays[i], grad_arrays[i], lr, wd); diff --git a/cpp-package/include/mxnet-cpp/io.h b/cpp-package/include/mxnet-cpp/io.h index 41c02f249614..2e96d1e39e73 100644 --- a/cpp-package/include/mxnet-cpp/io.h +++ b/cpp-package/include/mxnet-cpp/io.h @@ -119,7 +119,7 @@ class MXDataIter : public DataIter { DataIterCreator creator_; std::map params_; std::shared_ptr blob_ptr_; - static MXDataIterMap *mxdataiter_map_; + static MXDataIterMap*& mxdataiter_map(); }; } // namespace cpp } // namespace mxnet diff --git a/cpp-package/include/mxnet-cpp/io.hpp b/cpp-package/include/mxnet-cpp/io.hpp index 853a6bafb488..c905ef9ced55 100644 --- a/cpp-package/include/mxnet-cpp/io.hpp +++ b/cpp-package/include/mxnet-cpp/io.hpp @@ -14,46 +14,49 @@ namespace mxnet { namespace cpp { -MXDataIterMap *MXDataIter::mxdataiter_map_ = new MXDataIterMap; +inline MXDataIterMap*& MXDataIter::mxdataiter_map() { + static MXDataIterMap* mxdataiter_map_ = new MXDataIterMap; + return mxdataiter_map_; +} -MXDataIter::MXDataIter(const std::string &mxdataiter_type) { - creator_ = mxdataiter_map_->GetMXDataIterCreator(mxdataiter_type); +inline MXDataIter::MXDataIter(const std::string &mxdataiter_type) { + creator_ = mxdataiter_map()->GetMXDataIterCreator(mxdataiter_type); blob_ptr_ = std::make_shared(nullptr); } -void MXDataIter::BeforeFirst() { +inline void MXDataIter::BeforeFirst() { int r = MXDataIterBeforeFirst(blob_ptr_->handle_); CHECK_EQ(r, 0); } -bool MXDataIter::Next() { +inline bool MXDataIter::Next() { int out; int r = MXDataIterNext(blob_ptr_->handle_, &out); CHECK_EQ(r, 0); return out; } -NDArray MXDataIter::GetData() { +inline NDArray MXDataIter::GetData() { NDArrayHandle handle; int r = MXDataIterGetData(blob_ptr_->handle_, &handle); CHECK_EQ(r, 0); return NDArray(handle); } -NDArray MXDataIter::GetLabel() { +inline NDArray MXDataIter::GetLabel() { NDArrayHandle handle; int r = MXDataIterGetLabel(blob_ptr_->handle_, &handle); CHECK_EQ(r, 0); return NDArray(handle); } -int MXDataIter::GetPadNum() { +inline int MXDataIter::GetPadNum() { int out; int r = MXDataIterGetPadNum(blob_ptr_->handle_, &out); CHECK_EQ(r, 0); return out; } -std::vector MXDataIter::GetIndex() { +inline std::vector MXDataIter::GetIndex() { uint64_t *out_index, out_size; int r = MXDataIterGetIndex(blob_ptr_->handle_, &out_index, &out_size); CHECK_EQ(r, 0); @@ -64,7 +67,7 @@ std::vector MXDataIter::GetIndex() { return ret; } -MXDataIter MXDataIter::CreateDataIter() { +inline MXDataIter MXDataIter::CreateDataIter() { std::vector param_keys; std::vector param_values; diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h index a7f8404bed8e..c241cc481c11 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.h +++ b/cpp-package/include/mxnet-cpp/kvstore.h @@ -17,30 +17,30 @@ namespace cpp { class KVStore { public: - explicit inline KVStore(const std::string& name = "local"); - KVStore(const KVStore &) = delete; - // VS 2013 doesn't support default move constructor. - KVStore(KVStore &&); - inline void RunServer(); - inline void Init(int key, const NDArray& val); - inline void Init(const std::vector& keys, const std::vector& vals); - inline void Push(int key, const NDArray& val, int priority = 0); - inline void Push(const std::vector& keys, + static void SetType(const std::string& type); + static void RunServer(); + static void Init(int key, const NDArray& val); + static void Init(const std::vector& keys, const std::vector& vals); + static void Push(int key, const NDArray& val, int priority = 0); + static void Push(const std::vector& keys, const std::vector& vals, int priority = 0); - inline void Pull(int key, NDArray* out, int priority = 0); - inline void Pull(const std::vector& keys, std::vector* outs, int priority = 0); + static void Pull(int key, NDArray* out, int priority = 0); + static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); // TODO(lx): put lr in optimizer or not? - inline void SetOptimizer(std::unique_ptr optimizer, bool local = false); - inline std::string GetType() const; - inline int GetRank() const; - inline int GetNumWorkers() const; - inline void Barrier() const; - inline std::string GetRole() const; - ~KVStore() { MXKVStoreFree(handle_); } + static void SetOptimizer(std::unique_ptr optimizer, bool local = false); + static std::string GetType(); + static int GetRank(); + static int GetNumWorkers(); + static void Barrier(); + static std::string GetRole(); private: - KVStoreHandle handle_; - std::unique_ptr optimizer_; + KVStore(); + static KVStoreHandle& get_handle(); + static std::unique_ptr& get_optimizer(); + static KVStore*& get_kvstore(); + static void Controller(int head, const char* body, void* controller_handle); + static void Updater(int key, NDArrayHandle recv, NDArrayHandle local, void* handle_); }; } // namespace cpp diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index f4dd765d2f8b..c84d010de358 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -20,54 +20,56 @@ namespace mxnet { namespace cpp { -namespace private_ { - KVStore *kvstore = nullptr; - - extern "C" - void controller(int head, const char* body, void * controller_handle) { - if (kvstore == nullptr) { - return; +inline void KVStore::Controller(int head, const char* body, void* controller_handle) { + if (head == 0) { + std::map params; + std::istringstream sin(body); + std::string line; + while (getline(sin, line)) { + size_t n = line.find('='); + params.emplace(line.substr(0, n), line.substr(n+1)); } - if (head == 0) { - std::map params; - std::istringstream sin(body); - std::string line; - while (getline(sin, line)) { - size_t n = line.find('='); - params.emplace(line.substr(0, n), line.substr(n+1)); - } - std::unique_ptr opt(OptimizerRegistry::Find(params.at("opt_type"))); - params.erase("opt_type"); - for (const auto& pair : params) { - opt->SetParam(pair.first, pair.second); - } - kvstore->SetOptimizer(std::move(opt), true); + std::unique_ptr opt(OptimizerRegistry::Find(params.at("opt_type"))); + params.erase("opt_type"); + for (const auto& pair : params) { + opt->SetParam(pair.first, pair.second); } + get_kvstore()->SetOptimizer(std::move(opt), true); } -} // namespace private_ +} + +inline KVStoreHandle& KVStore::get_handle() { + static KVStoreHandle handle_ = nullptr; + return handle_; +} -KVStore::KVStore(const std::string& name) { - CHECK_EQ(MXKVStoreCreate(name.c_str(), &handle_), 0); +inline std::unique_ptr& KVStore::get_optimizer() { + static std::unique_ptr optimizer_; + return optimizer_; } -KVStore::KVStore(KVStore &&kv) { - optimizer_ = std::move(kv.optimizer_); - handle_ = kv.handle_; - kv.handle_ = nullptr; +inline KVStore*& KVStore::get_kvstore() { + static KVStore* kvstore_ = new KVStore; + return kvstore_; } -void KVStore::RunServer() { +inline KVStore::KVStore() {}; + +inline void KVStore::SetType(const std::string& type) { + CHECK_EQ(MXKVStoreCreate(type.c_str(), &(get_kvstore()->get_handle())), 0); +} + +inline void KVStore::RunServer() { CHECK_NE(GetRole(), "worker"); - private_::kvstore = this; - CHECK_EQ(MXKVStoreRunServer(handle_, &private_::controller, 0), 0); + CHECK_EQ(MXKVStoreRunServer(get_kvstore()->get_handle(), &Controller, 0), 0); } -void KVStore::Init(int key, const NDArray& val) { +inline void KVStore::Init(int key, const NDArray& val) { NDArrayHandle val_handle = val.GetHandle(); - CHECK_EQ(MXKVStoreInit(handle_, 1, &key, &val_handle), 0); + CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0); } -void KVStore::Init(const std::vector& keys, const std::vector& vals) { +inline void KVStore::Init(const std::vector& keys, const std::vector& vals) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), @@ -75,18 +77,18 @@ void KVStore::Init(const std::vector& keys, const std::vector& val return val.GetHandle(); }); - CHECK_EQ(MXKVStoreInit(handle_, keys.size(), keys.data(), + CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), keys.size(), keys.data(), val_handles.data()), 0); } -void KVStore::Push(int key, const NDArray& val, int priority) { +inline void KVStore::Push(int key, const NDArray& val, int priority) { NDArrayHandle val_handle = val.GetHandle(); - CHECK_EQ(MXKVStorePush(handle_, 1, &key, &val_handle, priority), 0); + CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0); } -void KVStore::Push(const std::vector& keys, - const std::vector& vals, - int priority) { +inline void KVStore::Push(const std::vector& keys, + const std::vector& vals, + int priority) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), @@ -94,16 +96,16 @@ void KVStore::Push(const std::vector& keys, return val.GetHandle(); }); - CHECK_EQ(MXKVStorePush(handle_, keys.size(), keys.data(), + CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), keys.size(), keys.data(), val_handles.data(), priority), 0); } -void KVStore::Pull(int key, NDArray* out, int priority) { +inline void KVStore::Pull(int key, NDArray* out, int priority) { NDArrayHandle out_handle = out->GetHandle(); - CHECK_EQ(MXKVStorePull(handle_, 1, &key, &out_handle, priority), 0); + CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0); } -void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { +inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { CHECK_EQ(keys.size(), outs->size()); std::vector out_handles(keys.size()); @@ -112,52 +114,48 @@ void KVStore::Pull(const std::vector& keys, std::vector* outs, int return val.GetHandle(); }); - CHECK_EQ(MXKVStorePull(handle_, keys.size(), keys.data(), + CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), keys.size(), keys.data(), out_handles.data(), priority), 0); } -namespace private_ { - extern "C" - void updater(int key, NDArrayHandle recv, NDArrayHandle local, - void* handle_) { - Optimizer *opt = static_cast(handle_); - opt->Update(key, NDArray(local), NDArray(recv)); - } +inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local, + void* handle_) { + Optimizer *opt = static_cast(handle_); + opt->Update(key, NDArray(local), NDArray(recv)); } -void KVStore::SetOptimizer(std::unique_ptr optimizer, bool local) { +inline void KVStore::SetOptimizer(std::unique_ptr optimizer, bool local) { if (local) { - optimizer_ = std::move(optimizer); - CHECK_EQ(MXKVStoreSetUpdater(handle_, &private_::updater, optimizer_.get()), 0); + get_kvstore()->get_optimizer() = std::move(optimizer); + CHECK_EQ(MXKVStoreSetUpdater(get_kvstore()->get_handle(), &Updater, get_kvstore()->get_optimizer().get()), 0); } else { - CHECK_EQ(MXKVStoreSendCommmandToServers(handle_, 0, (*optimizer).Serialize().c_str()), 0); + CHECK_EQ(MXKVStoreSendCommmandToServers(get_kvstore()->get_handle(), 0, (*optimizer).Serialize().c_str()), 0); } } -std::string KVStore::GetType() const { +inline std::string KVStore::GetType() { const char *type; - CHECK_EQ(MXKVStoreGetType(handle_, &type), 0); - // type is managed by handle_, no need to free its memory. + CHECK_EQ(MXKVStoreGetType(get_kvstore()->get_handle(), &type), 0); return type; } -int KVStore::GetRank() const { +inline int KVStore::GetRank() { int rank; - CHECK_EQ(MXKVStoreGetRank(handle_, &rank), 0); + CHECK_EQ(MXKVStoreGetRank(get_kvstore()->get_handle(), &rank), 0); return rank; } -int KVStore::GetNumWorkers() const { +inline int KVStore::GetNumWorkers() { int num_workers; - CHECK_EQ(MXKVStoreGetGroupSize(handle_, &num_workers), 0); + CHECK_EQ(MXKVStoreGetGroupSize(get_kvstore()->get_handle(), &num_workers), 0); return num_workers; } -void KVStore::Barrier() const { - CHECK_EQ(MXKVStoreBarrier(handle_), 0); +inline void KVStore::Barrier() { + CHECK_EQ(MXKVStoreBarrier(get_kvstore()->get_handle()), 0); } -std::string KVStore::GetRole() const { +inline std::string KVStore::GetRole() { int ret; CHECK_EQ(MXKVStoreIsSchedulerNode(&ret), 0); if (ret) { diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index f7b5c3233205..ef52df762aab 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -17,37 +17,37 @@ namespace mxnet { namespace cpp { -NDArray::NDArray() { +inline NDArray::NDArray() { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreateNone(&handle), 0); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const NDArrayHandle &handle) { +inline NDArray::NDArray(const NDArrayHandle &handle) { blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const std::vector &shape, const Context &context, - bool delay_alloc) { +inline NDArray::NDArray(const std::vector &shape, const Context &context, + bool delay_alloc) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreate(shape.data(), shape.size(), context.GetDeviceType(), context.GetDeviceId(), delay_alloc, &handle), 0); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const Shape &shape, const Context &context, bool delay_alloc) { +inline NDArray::NDArray(const Shape &shape, const Context &context, bool delay_alloc) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(), context.GetDeviceId(), delay_alloc, &handle), 0); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const mx_float *data, size_t size) { +inline NDArray::NDArray(const mx_float *data, size_t size) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreateNone(&handle), 0); MXNDArraySyncCopyFromCPU(handle, data, size); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const mx_float *data, const Shape &shape, - const Context &context) { +inline NDArray::NDArray(const mx_float *data, const Shape &shape, + const Context &context) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(), context.GetDeviceId(), false, &handle), @@ -55,8 +55,8 @@ NDArray::NDArray(const mx_float *data, const Shape &shape, MXNDArraySyncCopyFromCPU(handle, data, shape.Size()); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const std::vector &data, const Shape &shape, - const Context &context) { +inline NDArray::NDArray(const std::vector &data, const Shape &shape, + const Context &context) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(), context.GetDeviceId(), false, &handle), @@ -64,125 +64,125 @@ NDArray::NDArray(const std::vector &data, const Shape &shape, MXNDArraySyncCopyFromCPU(handle, data.data(), shape.Size()); blob_ptr_ = std::make_shared(handle); } -NDArray::NDArray(const std::vector &data) { +inline NDArray::NDArray(const std::vector &data) { NDArrayHandle handle; CHECK_EQ(MXNDArrayCreateNone(&handle), 0); MXNDArraySyncCopyFromCPU(handle, data.data(), data.size()); blob_ptr_ = std::make_shared(handle); } -NDArray NDArray::operator+(mx_float scalar) { +inline NDArray NDArray::operator+(mx_float scalar) { NDArray ret; Operator("_plus_scalar")(*this, scalar).Invoke(ret); return ret; } -NDArray NDArray::operator-(mx_float scalar) { +inline NDArray NDArray::operator-(mx_float scalar) { NDArray ret; Operator("_minus_scalar")(*this, scalar).Invoke(ret); return ret; } -NDArray NDArray::operator*(mx_float scalar) { +inline NDArray NDArray::operator*(mx_float scalar) { NDArray ret; Operator("_mul_scalar")(*this, scalar).Invoke(ret); return ret; } -NDArray NDArray::operator/(mx_float scalar) { +inline NDArray NDArray::operator/(mx_float scalar) { NDArray ret; Operator("_div_scalar")(*this, scalar).Invoke(ret); return ret; } -NDArray NDArray::operator+(const NDArray &rhs) { +inline NDArray NDArray::operator+(const NDArray &rhs) { NDArray ret; Operator("_plus")(*this, rhs).Invoke(ret); return ret; } -NDArray NDArray::operator-(const NDArray &rhs) { +inline NDArray NDArray::operator-(const NDArray &rhs) { NDArray ret; Operator("_minus")(*this, rhs).Invoke(ret); return ret; } -NDArray NDArray::operator*(const NDArray &rhs) { +inline NDArray NDArray::operator*(const NDArray &rhs) { NDArray ret; Operator("_mul")(*this, rhs).Invoke(ret); return ret; } -NDArray NDArray::operator/(const NDArray &rhs) { +inline NDArray NDArray::operator/(const NDArray &rhs) { NDArray ret; Operator("_div")(*this, rhs).Invoke(ret); return ret; } -NDArray &NDArray::operator=(mx_float scalar) { +inline NDArray &NDArray::operator=(mx_float scalar) { Operator("_set_value")(scalar).Invoke(*this); return *this; } -NDArray &NDArray::operator+=(mx_float scalar) { +inline NDArray &NDArray::operator+=(mx_float scalar) { Operator("_plus_scalar")(*this, scalar).Invoke(*this); return *this; } -NDArray &NDArray::operator-=(mx_float scalar) { +inline NDArray &NDArray::operator-=(mx_float scalar) { Operator("_minus_scalar")(*this, scalar).Invoke(*this); return *this; } -NDArray &NDArray::operator*=(mx_float scalar) { +inline NDArray &NDArray::operator*=(mx_float scalar) { Operator("_mul_scalar")(*this, scalar).Invoke(*this); return *this; } -NDArray &NDArray::operator/=(mx_float scalar) { +inline NDArray &NDArray::operator/=(mx_float scalar) { Operator("_div_scalar")(*this, scalar).Invoke(*this); return *this; } -NDArray &NDArray::operator+=(const NDArray &rhs) { +inline NDArray &NDArray::operator+=(const NDArray &rhs) { Operator("_plus")(*this, rhs).Invoke(*this); return *this; } -NDArray &NDArray::operator-=(const NDArray &rhs) { +inline NDArray &NDArray::operator-=(const NDArray &rhs) { Operator("_minus")(*this, rhs).Invoke(*this); return *this; } -NDArray &NDArray::operator*=(const NDArray &rhs) { +inline NDArray &NDArray::operator*=(const NDArray &rhs) { Operator("_mul")(*this, rhs).Invoke(*this); return *this; } -NDArray &NDArray::operator/=(const NDArray &rhs) { +inline NDArray &NDArray::operator/=(const NDArray &rhs) { Operator("_div")(*this, rhs).Invoke(*this); return *this; } -NDArray NDArray::ArgmaxChannel() { +inline NDArray NDArray::ArgmaxChannel() { NDArray ret; Operator("argmax_channel")(*this).Invoke(ret); return ret; } -void NDArray::SyncCopyFromCPU(const mx_float *data, size_t size) { +inline void NDArray::SyncCopyFromCPU(const mx_float *data, size_t size) { MXNDArraySyncCopyFromCPU(blob_ptr_->handle_, data, size); } -void NDArray::SyncCopyFromCPU(const std::vector &data) { +inline void NDArray::SyncCopyFromCPU(const std::vector &data) { MXNDArraySyncCopyFromCPU(blob_ptr_->handle_, data.data(), data.size()); } -void NDArray::SyncCopyToCPU(mx_float *data, size_t size) { +inline void NDArray::SyncCopyToCPU(mx_float *data, size_t size) { MXNDArraySyncCopyToCPU(blob_ptr_->handle_, data, size > 0 ? size : Size()); } -void NDArray::SyncCopyToCPU(std::vector *data, size_t size) { +inline void NDArray::SyncCopyToCPU(std::vector *data, size_t size) { size = size > 0 ? size : Size(); data->resize(size); MXNDArraySyncCopyToCPU(blob_ptr_->handle_, data->data(), size); } -NDArray NDArray::Copy(const Context &ctx) const { +inline NDArray NDArray::Copy(const Context &ctx) const { NDArray ret(GetShape(), ctx); Operator("_copyto")(*this).Invoke(ret); return ret; } -NDArray NDArray::CopyTo(NDArray * other) const { +inline NDArray NDArray::CopyTo(NDArray * other) const { Operator("_copyto")(*this).Invoke(*other); return *other; } -NDArray NDArray::Slice(mx_uint begin, mx_uint end) const { +inline NDArray NDArray::Slice(mx_uint begin, mx_uint end) const { NDArrayHandle handle; CHECK_EQ(MXNDArraySlice(GetHandle(), begin, end, &handle), 0); return NDArray(handle); } -NDArray NDArray::Reshape(const Shape &new_shape) const { +inline NDArray NDArray::Reshape(const Shape &new_shape) const { NDArrayHandle handle; std::vector dims(new_shape.ndim()); for (index_t i = 0; i < new_shape.ndim(); ++i) { @@ -193,22 +193,22 @@ NDArray NDArray::Reshape(const Shape &new_shape) const { MXNDArrayReshape(GetHandle(), new_shape.ndim(), dims.data(), &handle), 0); return NDArray(handle); } -void NDArray::WaitToRead() const { +inline void NDArray::WaitToRead() const { CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0); } -void NDArray::WaitToWrite() { +inline void NDArray::WaitToWrite() { CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0); } -void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); } -void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) { +inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); } +inline void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) { Operator("_sample_normal")(mu, sigma).Invoke(*out); } -void NDArray::SampleUniform(mx_float begin, mx_float end, NDArray *out) { +inline void NDArray::SampleUniform(mx_float begin, mx_float end, NDArray *out) { Operator("_sample_uniform")(begin, end).Invoke(*out); } -void NDArray::Load(const std::string &file_name, - std::vector *array_list, - std::map *array_map) { +inline void NDArray::Load(const std::string &file_name, + std::vector *array_list, + std::map *array_map) { mx_uint out_size, out_name_size; NDArrayHandle *out_arr; const char **out_names; @@ -227,7 +227,7 @@ void NDArray::Load(const std::string &file_name, } } } -std::map NDArray::LoadToMap( +inline std::map NDArray::LoadToMap( const std::string &file_name) { std::map array_map; mx_uint out_size, out_name_size; @@ -244,7 +244,7 @@ std::map NDArray::LoadToMap( } return array_map; } -std::vector NDArray::LoadToList(const std::string &file_name) { +inline std::vector NDArray::LoadToList(const std::string &file_name) { std::vector array_list; mx_uint out_size, out_name_size; NDArrayHandle *out_arr; @@ -257,8 +257,8 @@ std::vector NDArray::LoadToList(const std::string &file_name) { } return array_list; } -void NDArray::Save(const std::string &file_name, - const std::map &array_map) { +inline void NDArray::Save(const std::string &file_name, + const std::map &array_map) { std::vector args; std::vector keys; for (const auto &t : array_map) { @@ -269,8 +269,8 @@ void NDArray::Save(const std::string &file_name, MXNDArraySave(file_name.c_str(), args.size(), args.data(), keys.data()), 0); } -void NDArray::Save(const std::string &file_name, - const std::vector &array_list) { +inline void NDArray::Save(const std::string &file_name, + const std::vector &array_list) { std::vector args; for (const auto &t : array_list) { args.push_back(t.GetHandle()); @@ -279,30 +279,30 @@ void NDArray::Save(const std::string &file_name, 0); } -size_t NDArray::Offset(size_t h, size_t w) const { +inline size_t NDArray::Offset(size_t h, size_t w) const { return (h * GetShape()[1]) + w; } -size_t NDArray::Offset(size_t c, size_t h, size_t w) const { +inline size_t NDArray::Offset(size_t c, size_t h, size_t w) const { auto const shape = GetShape(); return h * shape[0] * shape[2] + w * shape[0] + c; } -mx_float NDArray::At(size_t h, size_t w) const { +inline mx_float NDArray::At(size_t h, size_t w) const { return GetData()[Offset(h, w)]; } -mx_float NDArray::At(size_t c, size_t h, size_t w) const { +inline mx_float NDArray::At(size_t c, size_t h, size_t w) const { return GetData()[Offset(c, h, w)]; } -size_t NDArray::Size() const { +inline size_t NDArray::Size() const { size_t ret = 1; for (auto &i : GetShape()) ret *= i; return ret; } -std::vector NDArray::GetShape() const { +inline std::vector NDArray::GetShape() const { const mx_uint *out_pdata; mx_uint out_dim; MXNDArrayGetShape(blob_ptr_->handle_, &out_dim, &out_pdata); @@ -313,13 +313,13 @@ std::vector NDArray::GetShape() const { return ret; } -const mx_float *NDArray::GetData() const { +inline const mx_float *NDArray::GetData() const { mx_float *ret; CHECK_NE(GetContext().GetDeviceType(), DeviceType::kGPU); MXNDArrayGetData(blob_ptr_->handle_, &ret); return ret; } -Context NDArray::GetContext() const { +inline Context NDArray::GetContext() const { int out_dev_type; int out_dev_id; MXNDArrayGetContext(blob_ptr_->handle_, &out_dev_type, &out_dev_id); diff --git a/cpp-package/include/mxnet-cpp/operator.h b/cpp-package/include/mxnet-cpp/operator.h index 66aec7fa0eda..153679c4d695 100644 --- a/cpp-package/include/mxnet-cpp/operator.h +++ b/cpp-package/include/mxnet-cpp/operator.h @@ -180,7 +180,7 @@ class Operator { std::vector input_keys; std::vector arg_names_; AtomicSymbolCreator handle_; - static OpMap *op_map_; + static OpMap*& op_map(); }; } // namespace cpp } // namespace mxnet diff --git a/cpp-package/include/mxnet-cpp/operator.hpp b/cpp-package/include/mxnet-cpp/operator.hpp index 3c8d1afe9b5f..84d8564137ff 100644 --- a/cpp-package/include/mxnet-cpp/operator.hpp +++ b/cpp-package/include/mxnet-cpp/operator.hpp @@ -24,20 +24,23 @@ namespace cpp { * like PushInput, which is not allowed in C++ */ template <> -Operator& Operator::SetParam(int pos, const NDArray &value) { +inline Operator& Operator::SetParam(int pos, const NDArray &value) { input_ndarrays.push_back(value.GetHandle()); return *this; } template <> -Operator& Operator::SetParam(int pos, const Symbol &value) { +inline Operator& Operator::SetParam(int pos, const Symbol &value) { input_symbols.push_back(value.GetHandle()); return *this; } -OpMap *Operator::op_map_ = new OpMap(); +inline OpMap*& Operator::op_map() { + static OpMap *op_map_ = new OpMap(); + return op_map_; +} -Operator::Operator(const std::string &operator_name) { - handle_ = op_map_->GetSymbolCreator(operator_name); +inline Operator::Operator(const std::string &operator_name) { + handle_ = op_map()->GetSymbolCreator(operator_name); const char *name; const char *description; mx_uint num_args; @@ -58,7 +61,7 @@ Operator::Operator(const std::string &operator_name) { } } -Symbol Operator::CreateSymbol(const std::string &name) { +inline Symbol Operator::CreateSymbol(const std::string &name) { if (input_keys.size() > 0) { CHECK_EQ(input_keys.size(), input_symbols.size()); } @@ -86,7 +89,7 @@ Symbol Operator::CreateSymbol(const std::string &name) { return Symbol(symbol_handle); } -void Operator::Invoke(std::vector &outputs) { +inline void Operator::Invoke(std::vector &outputs) { if (input_keys.size() > 0) { CHECK_EQ(input_keys.size(), input_ndarrays.size()); } @@ -126,24 +129,24 @@ void Operator::Invoke(std::vector &outputs) { }); } -std::vector Operator::Invoke() { +inline std::vector Operator::Invoke() { std::vector outputs; Invoke(outputs); return outputs; } -void Operator::Invoke(NDArray &output) { +inline void Operator::Invoke(NDArray &output) { std::vector outputs{output}; Invoke(outputs); } -Operator &Operator::SetInput(const std::string &name, Symbol symbol) { +inline Operator &Operator::SetInput(const std::string &name, Symbol symbol) { input_keys.push_back(name.c_str()); input_symbols.push_back(symbol.GetHandle()); return *this; } -Operator &Operator::SetInput(const std::string &name, NDArray ndarray) { +inline Operator &Operator::SetInput(const std::string &name, NDArray ndarray) { input_keys.push_back(name.c_str()); input_ndarrays.push_back(ndarray.GetHandle()); return *this; diff --git a/cpp-package/include/mxnet-cpp/optimizer.h b/cpp-package/include/mxnet-cpp/optimizer.h index 03c0d90c7b97..94d366f02da4 100644 --- a/cpp-package/include/mxnet-cpp/optimizer.h +++ b/cpp-package/include/mxnet-cpp/optimizer.h @@ -81,7 +81,7 @@ class Optimizer { protected: std::map params_; - static OpMap *op_map_; + static OpMap*& op_map(); const std::vector GetParamKeys_() const; const std::vector GetParamValues_() const; }; @@ -93,7 +93,7 @@ class OptimizerRegistry { static Optimizer* Find(const std::string& name); static int __REGISTER__(const std::string& name, OptimizerCreator creator); private: - static std::map cmap_; + static std::map& cmap(); OptimizerRegistry() = delete; ~OptimizerRegistry() = delete; }; diff --git a/cpp-package/include/mxnet-cpp/optimizer.hpp b/cpp-package/include/mxnet-cpp/optimizer.hpp index 94af0ec759a2..2a6c6a67a708 100644 --- a/cpp-package/include/mxnet-cpp/optimizer.hpp +++ b/cpp-package/include/mxnet-cpp/optimizer.hpp @@ -21,23 +21,26 @@ namespace mxnet { namespace cpp { -OpMap* Optimizer::op_map_ = new OpMap(); - -std::map OptimizerRegistry::cmap_; +inline std::map& OptimizerRegistry::cmap() { + static std::map cmap_; + return cmap_; +} -MXNETCPP_REGISTER_OPTIMIZER(sgd, SGDOptimizer); -MXNETCPP_REGISTER_OPTIMIZER(ccsgd, SGDOptimizer); // For backward compatibility +inline OpMap*& Optimizer::op_map() { + static OpMap *op_map_ = new OpMap(); + return op_map_; +} -Optimizer::~Optimizer() {} +inline Optimizer::~Optimizer() {} -void Optimizer::Update(int index, NDArray weight, NDArray grad, mx_float lr, +inline void Optimizer::Update(int index, NDArray weight, NDArray grad, mx_float lr, mx_float wd) { params_["lr"] = std::to_string(lr); params_["wd"] = std::to_string(wd); Update(index, weight, grad); } -std::string Optimizer::Serialize() const { +inline std::string Optimizer::Serialize() const { using ValueType = std::map::value_type; auto params = params_; params.emplace("opt_type", GetType()); @@ -47,49 +50,51 @@ std::string Optimizer::Serialize() const { }).substr(1); } -const std::vector Optimizer::GetParamKeys_() const { +inline const std::vector Optimizer::GetParamKeys_() const { std::vector keys; for (auto& iter : params_) keys.push_back(iter.first.c_str()); return keys; } -const std::vector Optimizer::GetParamValues_() const { +inline const std::vector Optimizer::GetParamValues_() const { std::vector values; for (auto& iter : params_) values.push_back(iter.second.c_str()); return values; } -Optimizer* OptimizerRegistry::Find(const std::string& name) { - auto it = cmap_.find(name); - if (it == cmap_.end()) +inline Optimizer* OptimizerRegistry::Find(const std::string& name) { + MXNETCPP_REGISTER_OPTIMIZER(sgd, SGDOptimizer); + MXNETCPP_REGISTER_OPTIMIZER(ccsgd, SGDOptimizer); // For backward compatibility + auto it = cmap().find(name); + if (it == cmap().end()) return nullptr; return it->second(); } -int OptimizerRegistry::__REGISTER__(const std::string& name, OptimizerCreator creator) { - CHECK_EQ(cmap_.count(name), 0) << name << " already registered"; - cmap_.emplace(name, std::move(creator)); +inline int OptimizerRegistry::__REGISTER__(const std::string& name, OptimizerCreator creator) { + CHECK_EQ(cmap().count(name), 0) << name << " already registered"; + cmap().emplace(name, std::move(creator)); return 0; } -std::string SGDOptimizer::GetType() const { +inline std::string SGDOptimizer::GetType() const { return "sgd"; } -SGDOptimizer::SGDOptimizer() { - update_handle_ = op_map_->GetSymbolCreator("sgd_update"); - mom_update_handle_ = op_map_->GetSymbolCreator("sgd_mom_update"); +inline SGDOptimizer::SGDOptimizer() { + update_handle_ = op_map()->GetSymbolCreator("sgd_update"); + mom_update_handle_ = op_map()->GetSymbolCreator("sgd_mom_update"); } -SGDOptimizer::~SGDOptimizer() { +inline SGDOptimizer::~SGDOptimizer() { for (auto &it : states_) { delete it.second; } } -void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) { +inline void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) { if (states_.count(index) == 0) { CreateState_(index, weight); } @@ -118,7 +123,7 @@ void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) { } } -void SGDOptimizer::CreateState_(int index, NDArray weight) { +inline void SGDOptimizer::CreateState_(int index, NDArray weight) { if (params_.count("momentum") == 0) { states_[index] = nullptr; } else { diff --git a/cpp-package/include/mxnet-cpp/symbol.h b/cpp-package/include/mxnet-cpp/symbol.h index 63ef9b1a03e3..0911c03e3937 100644 --- a/cpp-package/include/mxnet-cpp/symbol.h +++ b/cpp-package/include/mxnet-cpp/symbol.h @@ -246,7 +246,7 @@ class Symbol { private: std::shared_ptr blob_ptr_; - static OpMap *op_map_; + static OpMap*& op_map(); }; Symbol operator+(mx_float lhs, const Symbol &rhs); Symbol operator-(mx_float lhs, const Symbol &rhs); diff --git a/cpp-package/include/mxnet-cpp/symbol.hpp b/cpp-package/include/mxnet-cpp/symbol.hpp index f79e96a59fc4..b8e020b53584 100644 --- a/cpp-package/include/mxnet-cpp/symbol.hpp +++ b/cpp-package/include/mxnet-cpp/symbol.hpp @@ -20,39 +20,42 @@ namespace mxnet { namespace cpp { -OpMap *Symbol::op_map_ = new OpMap(); -Symbol::Symbol(SymbolHandle handle) { +inline OpMap*& Symbol::op_map() { + static OpMap* op_map_ = new OpMap(); + return op_map_; +} +inline Symbol::Symbol(SymbolHandle handle) { blob_ptr_ = std::make_shared(handle); } -Symbol::Symbol(const char *name) { +inline Symbol::Symbol(const char *name) { SymbolHandle handle; CHECK_EQ(MXSymbolCreateVariable(name, &(handle)), 0); blob_ptr_ = std::make_shared(handle); } -Symbol::Symbol(const std::string &name) : Symbol(name.c_str()) {} -Symbol Symbol::Variable(const std::string &name) { return Symbol(name); } -Symbol Symbol::operator+(const Symbol &rhs) const { return _Plus(*this, rhs); } -Symbol Symbol::operator-(const Symbol &rhs) const { return _Minus(*this, rhs); } -Symbol Symbol::operator*(const Symbol &rhs) const { return _Mul(*this, rhs); } -Symbol Symbol::operator/(const Symbol &rhs) const { return _Div(*this, rhs); } -Symbol Symbol::operator+(mx_float scalar) const { +inline Symbol::Symbol(const std::string &name) : Symbol(name.c_str()) {} +inline Symbol Symbol::Variable(const std::string &name) { return Symbol(name); } +inline Symbol Symbol::operator+(const Symbol &rhs) const { return _Plus(*this, rhs); } +inline Symbol Symbol::operator-(const Symbol &rhs) const { return _Minus(*this, rhs); } +inline Symbol Symbol::operator*(const Symbol &rhs) const { return _Mul(*this, rhs); } +inline Symbol Symbol::operator/(const Symbol &rhs) const { return _Div(*this, rhs); } +inline Symbol Symbol::operator+(mx_float scalar) const { return _PlusScalar(*this, scalar); } -Symbol Symbol::operator-(mx_float scalar) const { +inline Symbol Symbol::operator-(mx_float scalar) const { return _MinusScalar(*this, scalar); } -Symbol Symbol::operator*(mx_float scalar) const { +inline Symbol Symbol::operator*(mx_float scalar) const { return _MulScalar(*this, scalar); } -Symbol Symbol::operator/(mx_float scalar) const { +inline Symbol Symbol::operator/(mx_float scalar) const { return _DivScalar(*this, scalar); } -Symbol Symbol::operator[](int index) { +inline Symbol Symbol::operator[](int index) { SymbolHandle out; MXSymbolGetOutput(GetHandle(), index, &out); return Symbol(out); } -Symbol Symbol::operator[](const std::string &index) { +inline Symbol Symbol::operator[](const std::string &index) { auto outputs = ListOutputs(); for (mx_uint i = 0; i < outputs.size(); ++i) { if (outputs[i] == index) { @@ -62,7 +65,7 @@ Symbol Symbol::operator[](const std::string &index) { LOG(FATAL) << "Cannot find output that matches name " << index; return (*this)[0]; } -Symbol Symbol::Group(const std::vector &symbols) { +inline Symbol Symbol::Group(const std::vector &symbols) { SymbolHandle out; std::vector handle_list; for (const auto &t : symbols) { @@ -71,36 +74,36 @@ Symbol Symbol::Group(const std::vector &symbols) { MXSymbolCreateGroup(handle_list.size(), handle_list.data(), &out); return Symbol(out); } -Symbol Symbol::Load(const std::string &file_name) { +inline Symbol Symbol::Load(const std::string &file_name) { SymbolHandle handle; CHECK_EQ(MXSymbolCreateFromFile(file_name.c_str(), &(handle)), 0); return Symbol(handle); } -Symbol Symbol::LoadJSON(const std::string &json_str) { +inline Symbol Symbol::LoadJSON(const std::string &json_str) { SymbolHandle handle; CHECK_EQ(MXSymbolCreateFromJSON(json_str.c_str(), &(handle)), 0); return Symbol(handle); } -void Symbol::Save(const std::string &file_name) const { +inline void Symbol::Save(const std::string &file_name) const { CHECK_EQ(MXSymbolSaveToFile(GetHandle(), file_name.c_str()), 0); } -std::string Symbol::ToJSON() const { +inline std::string Symbol::ToJSON() const { const char *out_json; CHECK_EQ(MXSymbolSaveToJSON(GetHandle(), &out_json), 0); return std::string(out_json); } -Symbol Symbol::GetInternals() const { +inline Symbol Symbol::GetInternals() const { SymbolHandle handle; CHECK_EQ(MXSymbolGetInternals(GetHandle(), &handle), 0); return Symbol(handle); } -Symbol::Symbol(const std::string &operator_name, const std::string &name, +inline Symbol::Symbol(const std::string &operator_name, const std::string &name, std::vector input_keys, std::vector input_values, std::vector config_keys, std::vector config_values) { SymbolHandle handle; - AtomicSymbolCreator creator = op_map_->GetSymbolCreator(operator_name); + AtomicSymbolCreator creator = op_map()->GetSymbolCreator(operator_name); MXSymbolCreateAtomicSymbol(creator, config_keys.size(), config_keys.data(), config_values.data(), &handle); MXSymbolCompose(handle, operator_name.c_str(), input_keys.size(), @@ -108,13 +111,13 @@ Symbol::Symbol(const std::string &operator_name, const std::string &name, blob_ptr_ = std::make_shared(handle); } -Symbol Symbol::Copy() const { +inline Symbol Symbol::Copy() const { SymbolHandle handle; CHECK_EQ(MXSymbolCopy(GetHandle(), &handle), 0); return Symbol(handle); } -std::vector Symbol::ListArguments() const { +inline std::vector Symbol::ListArguments() const { std::vector ret; mx_uint size; const char **sarr; @@ -124,7 +127,7 @@ std::vector Symbol::ListArguments() const { } return ret; } -std::vector Symbol::ListOutputs() const { +inline std::vector Symbol::ListOutputs() const { std::vector ret; mx_uint size; const char **sarr; @@ -134,7 +137,7 @@ std::vector Symbol::ListOutputs() const { } return ret; } -std::vector Symbol::ListAuxiliaryStates() const { +inline std::vector Symbol::ListAuxiliaryStates() const { std::vector ret; mx_uint size; const char **sarr; @@ -145,7 +148,7 @@ std::vector Symbol::ListAuxiliaryStates() const { return ret; } -void Symbol::InferShape( +inline void Symbol::InferShape( const std::map > &arg_shapes, std::vector > *in_shape, std::vector > *aux_shape, @@ -205,7 +208,7 @@ void Symbol::InferShape( } } -void Symbol::InferExecutorArrays( +inline void Symbol::InferExecutorArrays( const Context &context, std::vector *arg_arrays, std::vector *grad_arrays, std::vector *grad_reqs, std::vector *aux_arrays, @@ -267,7 +270,7 @@ void Symbol::InferExecutorArrays( } } } -void Symbol::InferArgsMap( +inline void Symbol::InferArgsMap( const Context &context, std::map *args_map, const std::map &known_args) const { @@ -297,7 +300,7 @@ void Symbol::InferArgsMap( } } -Executor *Symbol::SimpleBind( +inline Executor *Symbol::SimpleBind( const Context &context, const std::map &args_map, const std::map &arg_grad_store, const std::map &grad_req_type, @@ -315,7 +318,7 @@ Executor *Symbol::SimpleBind( aux_arrays); } -Executor *Symbol::Bind(const Context &context, +inline Executor *Symbol::Bind(const Context &context, const std::vector &arg_arrays, const std::vector &grad_arrays, const std::vector &grad_reqs, @@ -325,12 +328,12 @@ Executor *Symbol::Bind(const Context &context, return new Executor(*this, context, arg_arrays, grad_arrays, grad_reqs, aux_arrays, group_to_ctx, shared_exec); } -Symbol operator+(mx_float lhs, const Symbol &rhs) { return rhs + lhs; } -Symbol operator-(mx_float lhs, const Symbol &rhs) { +inline Symbol operator+(mx_float lhs, const Symbol &rhs) { return rhs + lhs; } +inline Symbol operator-(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RMinusScalar(lhs, rhs); } -Symbol operator*(mx_float lhs, const Symbol &rhs) { return rhs * lhs; } -Symbol operator/(mx_float lhs, const Symbol &rhs) { +inline Symbol operator*(mx_float lhs, const Symbol &rhs) { return rhs * lhs; } +inline Symbol operator/(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RDivScalar(lhs, rhs); } } // namespace cpp