Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions cpp-package/include/mxnet-cpp/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

namespace mxnet {
namespace cpp {
Executor::Executor(const Symbol &symbol, Context context,
const std::vector<NDArray> &arg_arrays,
const std::vector<NDArray> &grad_arrays,
const std::vector<OpReqType> &grad_reqs,
const std::vector<NDArray> &aux_arrays,
const std::map<std::string, Context> &group_to_ctx,
Executor *shared_exec) {
inline Executor::Executor(const Symbol &symbol, Context context,
const std::vector<NDArray> &arg_arrays,
const std::vector<NDArray> &grad_arrays,
const std::vector<OpReqType> &grad_reqs,
const std::vector<NDArray> &aux_arrays,
const std::map<std::string, Context> &group_to_ctx,
Executor *shared_exec) {
this->arg_arrays = arg_arrays;
this->grad_arrays = grad_arrays;
this->aux_arrays = aux_arrays;
Expand Down Expand Up @@ -73,14 +73,14 @@ Executor::Executor(const Symbol &symbol, Context context,
}
}

std::string Executor::DebugStr() {
inline std::string Executor::DebugStr() {
const char *output;
MXExecutorPrint(handle_, &output);
return std::string(output);
}

void Executor::UpdateAll(Optimizer *opt, float lr, float wd,
int arg_update_begin, int arg_update_end) {
inline void Executor::UpdateAll(Optimizer *opt, float lr, float wd,
int arg_update_begin, int arg_update_end) {
arg_update_end = arg_update_end < 0 ? arg_arrays.size() - 1 : arg_update_end;
for (int i = arg_update_begin; i < arg_update_end; ++i) {
opt->Update(i, arg_arrays[i], grad_arrays[i], lr, wd);
Expand Down
2 changes: 1 addition & 1 deletion cpp-package/include/mxnet-cpp/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class MXDataIter : public DataIter {
DataIterCreator creator_;
std::map<std::string, std::string> params_;
std::shared_ptr<MXDataIterBlob> blob_ptr_;
static MXDataIterMap *mxdataiter_map_;
static MXDataIterMap*& mxdataiter_map();
};
} // namespace cpp
} // namespace mxnet
Expand Down
23 changes: 13 additions & 10 deletions cpp-package/include/mxnet-cpp/io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,49 @@
namespace mxnet {
namespace cpp {

MXDataIterMap *MXDataIter::mxdataiter_map_ = new MXDataIterMap;
inline MXDataIterMap*& MXDataIter::mxdataiter_map() {
static MXDataIterMap* mxdataiter_map_ = new MXDataIterMap;
return mxdataiter_map_;
}

MXDataIter::MXDataIter(const std::string &mxdataiter_type) {
creator_ = mxdataiter_map_->GetMXDataIterCreator(mxdataiter_type);
inline MXDataIter::MXDataIter(const std::string &mxdataiter_type) {
creator_ = mxdataiter_map()->GetMXDataIterCreator(mxdataiter_type);
blob_ptr_ = std::make_shared<MXDataIterBlob>(nullptr);
}

void MXDataIter::BeforeFirst() {
inline void MXDataIter::BeforeFirst() {
int r = MXDataIterBeforeFirst(blob_ptr_->handle_);
CHECK_EQ(r, 0);
}

bool MXDataIter::Next() {
inline bool MXDataIter::Next() {
int out;
int r = MXDataIterNext(blob_ptr_->handle_, &out);
CHECK_EQ(r, 0);
return out;
}

NDArray MXDataIter::GetData() {
inline NDArray MXDataIter::GetData() {
NDArrayHandle handle;
int r = MXDataIterGetData(blob_ptr_->handle_, &handle);
CHECK_EQ(r, 0);
return NDArray(handle);
}

NDArray MXDataIter::GetLabel() {
inline NDArray MXDataIter::GetLabel() {
NDArrayHandle handle;
int r = MXDataIterGetLabel(blob_ptr_->handle_, &handle);
CHECK_EQ(r, 0);
return NDArray(handle);
}

int MXDataIter::GetPadNum() {
inline int MXDataIter::GetPadNum() {
int out;
int r = MXDataIterGetPadNum(blob_ptr_->handle_, &out);
CHECK_EQ(r, 0);
return out;
}
std::vector<int> MXDataIter::GetIndex() {
inline std::vector<int> MXDataIter::GetIndex() {
uint64_t *out_index, out_size;
int r = MXDataIterGetIndex(blob_ptr_->handle_, &out_index, &out_size);
CHECK_EQ(r, 0);
Expand All @@ -64,7 +67,7 @@ std::vector<int> MXDataIter::GetIndex() {
return ret;
}

MXDataIter MXDataIter::CreateDataIter() {
inline MXDataIter MXDataIter::CreateDataIter() {
std::vector<const char *> param_keys;
std::vector<const char *> param_values;

Expand Down
40 changes: 20 additions & 20 deletions cpp-package/include/mxnet-cpp/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,30 @@ namespace cpp {

class KVStore {
public:
explicit inline KVStore(const std::string& name = "local");
KVStore(const KVStore &) = delete;
// VS 2013 doesn't support default move constructor.
KVStore(KVStore &&);
inline void RunServer();
inline void Init(int key, const NDArray& val);
inline void Init(const std::vector<int>& keys, const std::vector<NDArray>& vals);
inline void Push(int key, const NDArray& val, int priority = 0);
inline void Push(const std::vector<int>& keys,
static void SetType(const std::string& type);
static void RunServer();
static void Init(int key, const NDArray& val);
static void Init(const std::vector<int>& keys, const std::vector<NDArray>& vals);
static void Push(int key, const NDArray& val, int priority = 0);
static void Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals, int priority = 0);
inline void Pull(int key, NDArray* out, int priority = 0);
inline void Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority = 0);
static void Pull(int key, NDArray* out, int priority = 0);
static void Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority = 0);
// TODO(lx): put lr in optimizer or not?
inline void SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local = false);
inline std::string GetType() const;
inline int GetRank() const;
inline int GetNumWorkers() const;
inline void Barrier() const;
inline std::string GetRole() const;
~KVStore() { MXKVStoreFree(handle_); }
static void SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local = false);
static std::string GetType();
static int GetRank();
static int GetNumWorkers();
static void Barrier();
static std::string GetRole();

private:
KVStoreHandle handle_;
std::unique_ptr<Optimizer> optimizer_;
KVStore();
static KVStoreHandle& get_handle();
static std::unique_ptr<Optimizer>& get_optimizer();
static KVStore*& get_kvstore();
static void Controller(int head, const char* body, void* controller_handle);
static void Updater(int key, NDArrayHandle recv, NDArrayHandle local, void* handle_);
};

} // namespace cpp
Expand Down
130 changes: 64 additions & 66 deletions cpp-package/include/mxnet-cpp/kvstore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,90 +20,92 @@
namespace mxnet {
namespace cpp {

namespace private_ {
KVStore *kvstore = nullptr;

extern "C"
void controller(int head, const char* body, void * controller_handle) {
if (kvstore == nullptr) {
return;
inline void KVStore::Controller(int head, const char* body, void* controller_handle) {
if (head == 0) {
std::map<std::string, std::string> params;
std::istringstream sin(body);
std::string line;
while (getline(sin, line)) {
size_t n = line.find('=');
params.emplace(line.substr(0, n), line.substr(n+1));
}
if (head == 0) {
std::map<std::string, std::string> params;
std::istringstream sin(body);
std::string line;
while (getline(sin, line)) {
size_t n = line.find('=');
params.emplace(line.substr(0, n), line.substr(n+1));
}
std::unique_ptr<Optimizer> opt(OptimizerRegistry::Find(params.at("opt_type")));
params.erase("opt_type");
for (const auto& pair : params) {
opt->SetParam(pair.first, pair.second);
}
kvstore->SetOptimizer(std::move(opt), true);
std::unique_ptr<Optimizer> opt(OptimizerRegistry::Find(params.at("opt_type")));
params.erase("opt_type");
for (const auto& pair : params) {
opt->SetParam(pair.first, pair.second);
}
get_kvstore()->SetOptimizer(std::move(opt), true);
}
} // namespace private_
}

inline KVStoreHandle& KVStore::get_handle() {
static KVStoreHandle handle_ = nullptr;
return handle_;
}

KVStore::KVStore(const std::string& name) {
CHECK_EQ(MXKVStoreCreate(name.c_str(), &handle_), 0);
inline std::unique_ptr<Optimizer>& KVStore::get_optimizer() {
static std::unique_ptr<Optimizer> optimizer_;
return optimizer_;
}

KVStore::KVStore(KVStore &&kv) {
optimizer_ = std::move(kv.optimizer_);
handle_ = kv.handle_;
kv.handle_ = nullptr;
inline KVStore*& KVStore::get_kvstore() {
static KVStore* kvstore_ = new KVStore;
return kvstore_;
}

void KVStore::RunServer() {
inline KVStore::KVStore() {};

inline void KVStore::SetType(const std::string& type) {
CHECK_EQ(MXKVStoreCreate(type.c_str(), &(get_kvstore()->get_handle())), 0);
}

inline void KVStore::RunServer() {
CHECK_NE(GetRole(), "worker");
private_::kvstore = this;
CHECK_EQ(MXKVStoreRunServer(handle_, &private_::controller, 0), 0);
CHECK_EQ(MXKVStoreRunServer(get_kvstore()->get_handle(), &Controller, 0), 0);
}

void KVStore::Init(int key, const NDArray& val) {
inline void KVStore::Init(int key, const NDArray& val) {
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStoreInit(handle_, 1, &key, &val_handle), 0);
CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0);
}

void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArray>& vals) {
inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStoreInit(handle_, keys.size(), keys.data(),
CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), keys.size(), keys.data(),
val_handles.data()), 0);
}

void KVStore::Push(int key, const NDArray& val, int priority) {
inline void KVStore::Push(int key, const NDArray& val, int priority) {
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePush(handle_, 1, &key, &val_handle, priority), 0);
CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0);
}

void KVStore::Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals,
int priority) {
inline void KVStore::Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals,
int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStorePush(handle_, keys.size(), keys.data(),
CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), keys.size(), keys.data(),
val_handles.data(), priority), 0);
}

void KVStore::Pull(int key, NDArray* out, int priority) {
inline void KVStore::Pull(int key, NDArray* out, int priority) {
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePull(handle_, 1, &key, &out_handle, priority), 0);
CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0);
}

void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority) {
inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());

std::vector<NDArrayHandle> out_handles(keys.size());
Expand All @@ -112,52 +114,48 @@ void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int
return val.GetHandle();
});

CHECK_EQ(MXKVStorePull(handle_, keys.size(), keys.data(),
CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), keys.size(), keys.data(),
out_handles.data(), priority), 0);
}

namespace private_ {
extern "C"
void updater(int key, NDArrayHandle recv, NDArrayHandle local,
void* handle_) {
Optimizer *opt = static_cast<Optimizer*>(handle_);
opt->Update(key, NDArray(local), NDArray(recv));
}
inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local,
void* handle_) {
Optimizer *opt = static_cast<Optimizer*>(handle_);
opt->Update(key, NDArray(local), NDArray(recv));
}

void KVStore::SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local) {
inline void KVStore::SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local) {
if (local) {
optimizer_ = std::move(optimizer);
CHECK_EQ(MXKVStoreSetUpdater(handle_, &private_::updater, optimizer_.get()), 0);
get_kvstore()->get_optimizer() = std::move(optimizer);
CHECK_EQ(MXKVStoreSetUpdater(get_kvstore()->get_handle(), &Updater, get_kvstore()->get_optimizer().get()), 0);
} else {
CHECK_EQ(MXKVStoreSendCommmandToServers(handle_, 0, (*optimizer).Serialize().c_str()), 0);
CHECK_EQ(MXKVStoreSendCommmandToServers(get_kvstore()->get_handle(), 0, (*optimizer).Serialize().c_str()), 0);
}
}

std::string KVStore::GetType() const {
inline std::string KVStore::GetType() {
const char *type;
CHECK_EQ(MXKVStoreGetType(handle_, &type), 0);
// type is managed by handle_, no need to free its memory.
CHECK_EQ(MXKVStoreGetType(get_kvstore()->get_handle(), &type), 0);
return type;
}

int KVStore::GetRank() const {
inline int KVStore::GetRank() {
int rank;
CHECK_EQ(MXKVStoreGetRank(handle_, &rank), 0);
CHECK_EQ(MXKVStoreGetRank(get_kvstore()->get_handle(), &rank), 0);
return rank;
}

int KVStore::GetNumWorkers() const {
inline int KVStore::GetNumWorkers() {
int num_workers;
CHECK_EQ(MXKVStoreGetGroupSize(handle_, &num_workers), 0);
CHECK_EQ(MXKVStoreGetGroupSize(get_kvstore()->get_handle(), &num_workers), 0);
return num_workers;
}

void KVStore::Barrier() const {
CHECK_EQ(MXKVStoreBarrier(handle_), 0);
inline void KVStore::Barrier() {
CHECK_EQ(MXKVStoreBarrier(get_kvstore()->get_handle()), 0);
}

std::string KVStore::GetRole() const {
inline std::string KVStore::GetRole() {
int ret;
CHECK_EQ(MXKVStoreIsSchedulerNode(&ret), 0);
if (ret) {
Expand Down
Loading