Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/kvstore/kvstore_nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ class KVStoreNCCL : public KVStoreLocal {
int priority) override {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals);
// nccl kvstore doesn't support sparse ndarray
GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals, true);

std::vector<const NDArray*> merged_ptrs;
std::vector<NDArray*> local_ptrs;
Expand Down Expand Up @@ -146,10 +147,11 @@ class KVStoreNCCL : public KVStoreLocal {

void PullImpl(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority) override {
int priority, bool ignore_sparse) override {
CHECK(ignore_sparse) << "nccl kvstore pull doesn't support ignore_sparse=False";
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals);
GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals, true);
std::vector<NDArray> locals;
bool nccl_called = false;

Expand Down Expand Up @@ -191,17 +193,19 @@ class KVStoreNCCL : public KVStoreLocal {
void GroupKVPairsHelper(const std::vector<int>& keys,
const std::vector<T>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<T>> *grouped_vals) {
std::vector<std::vector<T>> *grouped_vals,
bool ignore_sparse) {
// check if the storage type of a value is valid
auto validator = [this](const int key, const T nd) -> bool {
auto validator = [this](const int key, const T nd, bool ignore_sparse) -> bool {
CHECK(ignore_sparse) << "nccl kvstore pull doesn't support ignore_sparse=False";
auto stype = ptr(nd)->storage_type();
// valid NDArray
if (stype == kDefaultStorage) return true;
// invalid NDArray, abort
LOG(FATAL) << "NCCL kvstore does not support sparse storage type";
return false;
};
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator);
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse);
}

private:
Expand Down