diff --git a/src/kvstore/kvstore_nccl.h b/src/kvstore/kvstore_nccl.h index 95ee8147a153..485cd9556003 100644 --- a/src/kvstore/kvstore_nccl.h +++ b/src/kvstore/kvstore_nccl.h @@ -91,7 +91,8 @@ class KVStoreNCCL : public KVStoreLocal { int priority) override { std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals); + // nccl kvstore doesn't support sparse ndarray + GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals, true); std::vector merged_ptrs; std::vector local_ptrs; @@ -146,10 +147,11 @@ class KVStoreNCCL : public KVStoreLocal { void PullImpl(const std::vector& keys, const std::vector& values, - int priority) override { + int priority, bool ignore_sparse) override { + CHECK(ignore_sparse) << "nccl kvstore pull doesn't support ignore_sparse=False"; std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals, true); std::vector locals; bool nccl_called = false; @@ -191,9 +193,11 @@ class KVStoreNCCL : public KVStoreLocal { void GroupKVPairsHelper(const std::vector& keys, const std::vector& values, std::vector *uniq_keys, - std::vector> *grouped_vals) { + std::vector> *grouped_vals, + bool ignore_sparse) { // check if the storage type of a value is valid - auto validator = [this](const int key, const T nd) -> bool { + auto validator = [this](const int key, const T nd, bool ignore_sparse) -> bool { + CHECK(ignore_sparse) << "nccl kvstore pull doesn't support ignore_sparse=False"; auto stype = ptr(nd)->storage_type(); // valid NDArray if (stype == kDefaultStorage) return true; @@ -201,7 +205,7 @@ class KVStoreNCCL : public KVStoreLocal { LOG(FATAL) << "NCCL kvstore does not support sparse storage type"; return false; }; - GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); + GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse); } private: