diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index da2d03d519f4..c007a78aefa3 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -22,19 +22,20 @@ */ #ifndef MXNET_KVSTORE_COMM_H_ #define MXNET_KVSTORE_COMM_H_ +#define NVLINK_SUPPORT 4 #include -#include #include -#include #include -#include -#include +#include #include -#include "mxnet/ndarray.h" -#include "gradient_compression.h" +#include +#include +#include #include "../ndarray/ndarray_function.h" #include "../operator/tensor/sparse_retain-inl.h" #include "./kvstore_utils.h" +#include "gradient_compression.h" +#include "mxnet/ndarray.h" namespace mxnet { namespace kvstore { /** @@ -42,10 +43,8 @@ namespace kvstore { */ class Comm { public: - Comm() { - pinned_ctx_ = Context::CPUPinned(0); - } - virtual ~Comm() { } + Comm() { pinned_ctx_ = Context::CPUPinned(0); } + virtual ~Comm() {} /** * \brief init key with the data shape and storage shape */ @@ -54,33 +53,32 @@ class Comm { /** * \brief returns src[0] + .. + src[src.size()-1] */ - virtual const NDArray& Reduce( - int key, const std::vector& src, int priority) = 0; + virtual const NDArray& Reduce(int key, const std::vector& src, + int priority) = 0; /** * \brief copy from src to dst[i] for every i */ - virtual void Broadcast( - int key, const NDArray& src, - const std::vector dst, int priority) = 0; + virtual void Broadcast(int key, const NDArray& src, + const std::vector dst, int priority) = 0; /** * \brief broadcast src to dst[i] with target row_ids for every i - * \param dst a list of destination row_sparse NDArray and its target row_ids to broadcast, + * \param dst a list of destination row_sparse NDArray and its target row_ids + to broadcast, where the row_ids are expected to be unique and sorted - * \param use_copy if set to true, directly copy src to dst[i] without looking up the + * \param use_copy if set to true, directly copy src to dst[i] without looking + up the provided row_ids */ - virtual void BroadcastRowSparse(int key, const NDArray& src, - const std::vector>& dst, - const bool use_copy, - const int priority) = 0; + virtual void BroadcastRowSparse( + int key, const NDArray& src, + const std::vector>& dst, const bool use_copy, + const int priority) = 0; /** * \brief return a pinned contex */ - Context pinned_ctx() const { - return pinned_ctx_; - } + Context pinned_ctx() const { return pinned_ctx_; } /** * \brief Sets gradient compression parameters to be able to @@ -108,7 +106,7 @@ class CommCPU : public Comm { // TODO(junwu) delete the following data member, now for benchmark only is_serial_push_ = dmlc::GetEnv("MXNET_KVSTORE_SERIAL_PUSH", 0); } - virtual ~CommCPU() { } + virtual ~CommCPU() {} void Init(int key, const NDArrayStorageType stype, const TShape& shape, int type = mshadow::kFloat32) override { @@ -121,7 +119,7 @@ class CommCPU : public Comm { const NDArray& Reduce(int key, const std::vector& src, int priority) override { - auto& buf = merge_buf_[key]; + BufferEntry& buf = merge_buf_[key]; // avoid extra copy for single device, but it may bring problems for // abnormal usage of kvstore if (src.size() == 1) { @@ -140,25 +138,28 @@ class CommCPU : public Comm { reduce[0] = buf.merged; if (buf.copy_buf.empty()) { - buf.copy_buf.resize(src.size()-1); + buf.copy_buf.resize(src.size() - 1); for (size_t j = 0; j < src.size() - 1; ++j) { // allocate NDArray based on storage type - buf.copy_buf[j] = NDArray( - src[0].shape(), pinned_ctx_, false, src[0].dtype()); + buf.copy_buf[j] = + NDArray(src[0].shape(), pinned_ctx_, false, src[0].dtype()); } } 
for (size_t i = 1; i < src.size(); ++i) { - CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); - reduce[i] = buf.copy_buf[i-1]; - const_vars[i-1] = reduce[i].var(); + CopyFromTo(src[i], &(buf.copy_buf[i - 1]), priority); + reduce[i] = buf.copy_buf[i - 1]; + const_vars[i - 1] = reduce[i].var(); } Engine::Get()->PushAsync( - [reduce, this](RunContext rctx, Engine::CallbackOnComplete on_complete) { - ReduceSumCPU(reduce); - on_complete(); - }, Context::CPU(), const_vars, {reduce[0].var()}, - FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + [reduce, this](RunContext rctx, + Engine::CallbackOnComplete on_complete) { + ReduceSumCPU(reduce); + on_complete(); + }, + Context::CPU(), const_vars, {reduce[0].var()}, + FnProperty::kCPUPrioritized, priority, + PROFILER_MESSAGE("KVStoreReduce")); } else { // buf.merged is a sparse ndarray. @@ -168,8 +169,8 @@ class CommCPU : public Comm { if (buf.copy_buf.empty()) { buf.copy_buf.resize(src.size()); for (size_t j = 0; j < src.size(); ++j) { - buf.copy_buf[j] = NDArray( - src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype()); + buf.copy_buf[j] = NDArray(src[0].storage_type(), src[0].shape(), + pinned_ctx_, true, src[0].dtype()); } } for (size_t i = 0; i < src.size(); ++i) { @@ -178,44 +179,46 @@ class CommCPU : public Comm { const_vars[i] = reduce[i].var(); } NDArray result = buf.merged; - Resource rsc = ResourceManager::Get()->Request(result.ctx(), - ResourceRequest(ResourceRequest::kTempSpace)); + Resource rsc = ResourceManager::Get()->Request( + result.ctx(), ResourceRequest(ResourceRequest::kTempSpace)); Engine::Get()->PushAsync( - [reduce, result, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) { - NDArray out = result; - is_serial_push_? - ReduceSumCPUExSerial(reduce, &out) - : mxnet::ndarray::ElementwiseSum(rctx.get_stream(), rsc, reduce, &out); - on_complete(); - }, Context::CPU(), const_vars, {result.var(), rsc.var}, - FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + [reduce, result, rsc, this](RunContext rctx, + Engine::CallbackOnComplete on_complete) { + NDArray out = result; + is_serial_push_ ? 
ReduceSumCPUExSerial(reduce, &out) + : mxnet::ndarray::ElementwiseSum( + rctx.get_stream(), rsc, reduce, &out); + on_complete(); + }, + Context::CPU(), const_vars, {result.var(), rsc.var}, + FnProperty::kCPUPrioritized, priority, + PROFILER_MESSAGE("KVStoreReduce")); } return buf.merged; } - void Broadcast(int key, const NDArray& src, - const std::vector dst, int priority) override { + void Broadcast(int key, const NDArray& src, const std::vector dst, + int priority) override { int mask = src.ctx().dev_mask(); if (mask == Context::kCPU) { - for (auto d : dst) CopyFromTo(src, d, priority); + for (auto& d : dst) CopyFromTo(src, d, priority); } else { // first copy data to cpu, then broadcast - auto& buf = merge_buf_[key]; + BufferEntry& buf = merge_buf_[key]; CopyFromTo(src, &buf.merged, priority); - for (auto d : dst) CopyFromTo(buf.merged, d, priority); + for (auto& d : dst) CopyFromTo(buf.merged, d, priority); } } void BroadcastRowSparse(int key, const NDArray& src, const std::vector>& dst, - const bool use_copy, - const int priority) override { + const bool use_copy, const int priority) override { using namespace mshadow; CHECK_EQ(src.storage_type(), kRowSparseStorage) - << "BroadcastRowSparse expects row-sparse src NDArray"; + << "BroadcastRowSparse expects row-sparse src NDArray"; CHECK_EQ(src.ctx().dev_mask(), Context::kCPU) - << "BroadcastRowSparse with src on gpu context not supported"; + << "BroadcastRowSparse with src on gpu context not supported"; for (size_t i = 0; i < dst.size(); ++i) { NDArray* out = dst[i].first; NDArray row_id = dst[i].second; @@ -223,40 +226,47 @@ class CommCPU : public Comm { CopyFromTo(src, out, priority); } else { CHECK_EQ(out->storage_type(), kRowSparseStorage) - << "BroadcastRowSparse expects row_sparse dst NDArray"; + << "BroadcastRowSparse expects row_sparse dst NDArray"; CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU) - << "BroadcastRowSparse with row_indices on gpu context not supported"; + << "BroadcastRowSparse with row_indices on gpu context not " + "supported"; // retain according to unique indices - const bool use_sparse_retain = (src.shape()[0] != src.storage_shape()[0]) - || (row_id.dtype() != out->aux_type(rowsparse::kIdx)) - || (out->ctx().dev_mask() != Context::kGPU); + const bool use_sparse_retain = + (src.shape()[0] != src.storage_shape()[0]) || + (row_id.dtype() != out->aux_type(rowsparse::kIdx)) || + (out->ctx().dev_mask() != Context::kGPU); if (use_sparse_retain) { // use sparse_retain op const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU; - NDArray out_cpu = is_to_gpu? NDArray(kRowSparseStorage, src.shape(), - src.ctx(), true, src.dtype(), src.aux_types()) : *out; + NDArray out_cpu = + is_to_gpu ? 
NDArray(kRowSparseStorage, src.shape(), src.ctx(), + true, src.dtype(), src.aux_types()) + : *out; Engine::Get()->PushAsync( - [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { - const TBlob& indices = row_id.data(); - NDArray temp = out_cpu; // get rid of const qualifier - op::SparseRetainOpForwardRspImpl(rctx.get_stream(), - src, indices, kWriteTo, - &temp); - on_complete(); - }, Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()}, - FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain")); + [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { + const TBlob& indices = row_id.data(); + NDArray temp = out_cpu; // get rid of const qualifier + op::SparseRetainOpForwardRspImpl( + rctx.get_stream(), src, indices, kWriteTo, &temp); + on_complete(); + }, + Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()}, + FnProperty::kNormal, priority, + PROFILER_MESSAGE("KVStoreSparseRetain")); if (is_to_gpu) { CopyFromTo(out_cpu, out, priority); } } else { // direct copy rows Engine::Get()->PushAsync( - [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { - CopyRetainedRowsToGPU(rctx.get_stream(), rctx.get_stream(), - src, row_id, out); - // wait for GPU operations to complete - rctx.get_stream()->Wait(); - on_complete(); - }, out->ctx(), {src.var(), row_id.var()}, {out->var()}, - FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("KVStoreCopyRetainedRowsToGPU")); + [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { + CopyRetainedRowsToGPU(rctx.get_stream(), + rctx.get_stream(), src, row_id, out); + // wait for GPU operations to complete + rctx.get_stream()->Wait(); + on_complete(); + }, + out->ctx(), {src.var(), row_id.var()}, {out->var()}, + FnProperty::kCopyToGPU, priority, + PROFILER_MESSAGE("KVStoreCopyRetainedRowsToGPU")); } } } @@ -270,22 +280,22 @@ class CommCPU : public Comm { */ void CopyRetainedRowsToGPU(mshadow::Stream* cpu_stream, mshadow::Stream* gpu_stream, - const NDArray& src, - const NDArray& indices, + const NDArray& src, const NDArray& indices, NDArray* dst) { #if MXNET_USE_CUDA == 1 CHECK_EQ(src.storage_type(), kRowSparseStorage) - << "CopyRetainedRowsToGPU expects row-sparse src NDArray"; + << "CopyRetainedRowsToGPU expects row-sparse src NDArray"; CHECK_EQ(src.ctx().dev_mask(), Context::kCPU) - << "CopyRetainedRowsToGPU with src on gpu context not supported"; + << "CopyRetainedRowsToGPU with src on gpu context not supported"; CHECK_EQ(src.storage_shape()[0], src.shape()[0]) - << "CopyRetainedRowsToGPU only supports src rsp with full rows"; + << "CopyRetainedRowsToGPU only supports src rsp with full rows"; CHECK_EQ(indices.storage_type(), kDefaultStorage); CHECK_EQ(indices.ctx().dev_mask(), Context::kCPU); CHECK_EQ(dst->storage_type(), kRowSparseStorage); CHECK_EQ(dst->ctx().dev_mask(), Context::kGPU); CHECK_EQ(indices.dtype(), dst->aux_type(rowsparse::kIdx)) - << "CopyRetainedRowsToGPU only supports same data type for idx array and dst aux_data(0)"; + << "CopyRetainedRowsToGPU only supports same data type for idx array " + "and dst aux_data(0)"; if (!src.storage_initialized() || indices.data().Size() == 0U) { op::FillZerosRspImpl(gpu_stream, *dst); return; @@ -299,29 +309,33 @@ class CommCPU : public Comm { dst->CheckAndAlloc({Shape1(num_rows_retained)}); TBlob dst_data = dst->data(); TBlob dst_idx_data = dst->aux_data(rowsparse::kIdx); - MSHADOW_TYPE_SWITCH(src.dtype(), DType, { - MSHADOW_IDX_TYPE_SWITCH(indices.dtype(), IType, { - // copy idx array - Tensor dst_idx_tensor = 
dst_idx_data.FlatTo1D(gpu_stream); - const Tensor idx_tensor = idx_data.FlatTo1D(cpu_stream); - Copy(dst_idx_tensor, idx_tensor, gpu_stream); - // copy src data - const Tensor src_data_tensor = src_data.get_with_shape( - Shape2(src_data.shape_[0], row_length), cpu_stream); - Tensor dst_data_tensor = dst_data.get_with_shape( - Shape2(dst_data.shape_[0], row_length), gpu_stream); - for (size_t i = 0; i < num_rows_retained; ++i) { - Copy(dst_data_tensor[i], src_data_tensor[idx_tensor[i]], gpu_stream); - } - }) - }) + MSHADOW_TYPE_SWITCH( + src.dtype(), DType, {MSHADOW_IDX_TYPE_SWITCH(indices.dtype(), IType, { + // copy idx array + Tensor dst_idx_tensor = + dst_idx_data.FlatTo1D(gpu_stream); + const Tensor idx_tensor = + idx_data.FlatTo1D(cpu_stream); + Copy(dst_idx_tensor, idx_tensor, gpu_stream); + // copy src data + const Tensor src_data_tensor = + src_data.get_with_shape( + Shape2(src_data.shape_[0], row_length), cpu_stream); + Tensor dst_data_tensor = + dst_data.get_with_shape( + Shape2(dst_data.shape_[0], row_length), gpu_stream); + for (size_t i = 0; i < num_rows_retained; ++i) { + Copy(dst_data_tensor[i], src_data_tensor[idx_tensor[i]], + gpu_stream); + } + })}) #else LOG(FATAL) << "GPU not enabled"; #endif } // reduce sum into val[0] - inline void ReduceSumCPU(const std::vector &in_data) { + inline void ReduceSumCPU(const std::vector& in_data) { MSHADOW_TYPE_SWITCH(in_data[0].dtype(), DType, { std::vector dptr(in_data.size()); for (size_t i = 0; i < in_data.size(); ++i) { @@ -335,7 +349,8 @@ class CommCPU : public Comm { } // serial implementation of reduce sum for row sparse NDArray. - inline void ReduceSumCPUExSerial(const std::vector &in, NDArray *out) { + inline void ReduceSumCPUExSerial(const std::vector& in, + NDArray* out) { using namespace rowsparse; using namespace mshadow; auto stype = out->storage_type(); @@ -374,7 +389,8 @@ class CommCPU : public Comm { CHECK_EQ(indices.size(), total_num_rows); // dedup indices std::sort(indices.begin(), indices.end()); - indices.resize(std::unique(indices.begin(), indices.end()) - indices.begin()); + indices.resize(std::unique(indices.begin(), indices.end()) - + indices.begin()); // the one left are unique non-zero rows size_t nnr = indices.size(); // allocate memory for output @@ -406,12 +422,12 @@ class CommCPU : public Comm { }); } - template - inline static void ReduceSumCPU( - const std::vector &dptr, size_t offset, index_t size) { + template + inline static void ReduceSumCPU(const std::vector& dptr, + size_t offset, index_t size) { using namespace mshadow; // NOLINT(*) Tensor in_0(dptr[0] + offset, Shape1(size)); - for (size_t i = 1; i < dptr.size(); i+=4) { + for (size_t i = 1; i < dptr.size(); i += 4) { switch (dptr.size() - i) { case 1: { Tensor in_1(dptr[i] + offset, Shape1(size)); @@ -420,22 +436,22 @@ class CommCPU : public Comm { } case 2: { Tensor in_1(dptr[i] + offset, Shape1(size)); - Tensor in_2(dptr[i+1] + offset, Shape1(size)); + Tensor in_2(dptr[i + 1] + offset, Shape1(size)); in_0 += in_1 + in_2; break; } case 3: { Tensor in_1(dptr[i] + offset, Shape1(size)); - Tensor in_2(dptr[i+1] + offset, Shape1(size)); - Tensor in_3(dptr[i+2] + offset, Shape1(size)); + Tensor in_2(dptr[i + 1] + offset, Shape1(size)); + Tensor in_3(dptr[i + 2] + offset, Shape1(size)); in_0 += in_1 + in_2 + in_3; break; } default: { Tensor in_1(dptr[i] + offset, Shape1(size)); - Tensor in_2(dptr[i+1] + offset, Shape1(size)); - Tensor in_3(dptr[i+2] + offset, Shape1(size)); - Tensor in_4(dptr[i+3] + offset, Shape1(size)); + Tensor in_2(dptr[i + 1] + 
offset, Shape1(size)); + Tensor in_3(dptr[i + 2] + offset, Shape1(size)); + Tensor in_4(dptr[i + 3] + offset, Shape1(size)); in_0 += in_1 + in_2 + in_3 + in_4; break; } @@ -443,15 +459,15 @@ class CommCPU : public Comm { } } - template + template inline void ReduceSumCPUImpl(std::vector dptr, size_t total) { const size_t step = std::min(bigarray_bound_, static_cast(4 << 10)); - long ntask = (total + step - 1) / step; // NOLINT(*) + long ntask = (total + step - 1) / step; // NOLINT(*) if (total < bigarray_bound_ || nthread_reduction_ <= 1) { ReduceSumCPU(dptr, 0, total); } else { - #pragma omp parallel for schedule(static) num_threads(nthread_reduction_) - for (long j = 0; j < ntask; ++j) { // NOLINT(*) +#pragma omp parallel for schedule(static) num_threads(nthread_reduction_) + for (long j = 0; j < ntask; ++j) { // NOLINT(*) size_t k = static_cast(j); size_t begin = std::min(k * step, total); size_t end = std::min((k + 1) * step, total); @@ -484,11 +500,9 @@ class CommCPU : public Comm { */ class CommDevice : public Comm { public: - CommDevice() { - inited_ = false; - } + CommDevice() { inited_ = false; } - virtual ~CommDevice() { } + virtual ~CommDevice() {} void Init(int key, const NDArrayStorageType stype, const TShape& shape, int dtype = mshadow::kFloat32) override { @@ -523,95 +537,236 @@ class CommDevice : public Comm { } InitBuffersAndComm(src); - auto& buf = merge_buf_[key]; - std::vector reduce(src.size()); - - const NDArrayStorageType stype = buf.merged.storage_type(); + // merge buffer holds the first group of gpus + BufferEntry& buf = merge_buf_[key]; + // stage buffer holds the data of the second group or the first when merge + // buffer is empty + BufferEntry& stage = stage_buf_[key]; + std::vector reduce_s; + + const NDArrayStorageType stype = stage.merged.storage_type(); if (stype == kDefaultStorage) { - CopyFromTo(src[0], &(buf.merged), priority); - reduce[0] = buf.merged; - - if (buf.copy_buf.empty()) { - // TODO(mli) this results in large device memory usage for huge ndarray, - // such as the largest fullc in VGG. consider to do segment reduce with - // NDArray.Slice or gpu direct memory access. 
for the latter, we need to - // remove some ctx check, and also it reduces 20% perf - buf.copy_buf.resize(src.size()-1); - for (size_t i = 0; i < src.size()-1; ++i) { - buf.copy_buf[i] = NDArray( - buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype()); - } + if (buf.merged.is_none() && stage.copy_buf.empty()) { + stage.copy_buf.resize(src.size() - 1); + for (size_t i = 0; i < src.size() - 1; ++i) + stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(), + false, stage.merged.dtype()); } - for (size_t i = 0; i < src.size()-1; ++i) { - CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority); - reduce[i+1] = buf.copy_buf[i]; + reduce_s.resize(stage.copy_buf.size() + 1); + for (size_t i = 0, j = 0; i < src.size(); ++i) { + int id = src[i].ctx().dev_id; + if ((!buf.merged.is_none() && id == stage.merged.ctx().dev_id) || + (buf.merged.is_none() && i == 0)) { + CopyFromTo(src[i], &stage.merged, priority); + reduce_s[0] = stage.merged; + } else if (id >= 4 || buf.merged.is_none()) { + CopyFromTo(src[i], &(stage.copy_buf[j]), priority); + reduce_s[j + 1] = stage.copy_buf[j]; + j++; + } } } else { - if (buf.copy_buf.empty()) { - buf.copy_buf.resize(src.size()); - for (size_t j = 0; j < src.size(); ++j) { - buf.copy_buf[j] = NDArray( - buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(), - true, buf.merged.dtype()); + if (buf.merged.is_none() && stage.copy_buf.empty()) { + stage.copy_buf.resize(src.size()); + for (size_t j = 0; j < src.size(); ++j) + stage.copy_buf[j] = + NDArray(stage.merged.storage_type(), stage.merged.shape(), + stage.merged.ctx(), true, stage.merged.dtype()); + } + reduce_s.resize(stage.copy_buf.size()); + for (size_t i = 0, j = 0; i < src.size(); ++i) { + int id = src[i].ctx().dev_id; + if (id >= 4 || buf.merged.is_none()) { + CopyFromTo(src[i], &(stage.copy_buf[j]), priority); + reduce_s[j] = stage.copy_buf[j]; + j++; } } - for (size_t i = 0; i < src.size(); ++i) { - CopyFromTo(src[i], &(buf.copy_buf[i]), priority); - reduce[i] = buf.copy_buf[i]; + } + // Reducing either the second group of data or the second when merge buffer + // is empty + ElementwiseSum(reduce_s, &stage.merged, priority); + // Main reduce result on the first group of GPUs including the partial + // result from the second group + if (!buf.merged.is_none()) { + const NDArrayStorageType sstype = buf.merged.storage_type(); + std::vector reduce; + if (sstype == kDefaultStorage) { + reduce.resize(buf.copy_buf.size() + 1); + for (size_t i = 0, j = 0; i < src.size(); ++i) { + int id = src[i].ctx().dev_id; + if (id == buf.merged.ctx().dev_id) { + reduce[0] = src[i]; + } else if (id < 4) { + CopyFromTo(src[i], &(buf.copy_buf[j]), priority); + reduce[j + 1] = buf.copy_buf[j]; + j++; + } + } + } else { + reduce.resize(buf.copy_buf.size()); + for (size_t i = 0, j = 0; i < src.size(); ++i) { + int id = src[i].ctx().dev_id; + if (id < 4) { + CopyFromTo(src[i], &(buf.copy_buf[j]), priority); + reduce[j] = buf.copy_buf[j]; + j++; + } + } } + // Copy the second group's reducing result to merge buffer + CopyFromTo(stage.merged, &(buf.copy_buf[buf.copy_buf.size() - 1]), + priority); + reduce[reduce.size() - 1] = buf.copy_buf[buf.copy_buf.size() - 1]; + ElementwiseSum(reduce, &buf.merged); + } else { + return stage.merged; } - ElementwiseSum(reduce, &buf.merged, priority); + return buf.merged; } const NDArray& ReduceCompressed(int key, const std::vector& src, int priority) { InitBuffersAndComm(src); - auto& buf = merge_buf_[key]; - std::vector reduce(src.size()); - if (buf.copy_buf.empty()) { + 
BufferEntry& buf = merge_buf_[key]; + BufferEntry& stage = stage_buf_[key]; + if (buf.merged.is_none() && stage.copy_buf.empty()) { // one buf for each context - buf.copy_buf.resize(src.size()); - buf.compressed_recv_buf.resize(src.size()); - buf.compressed_send_buf.resize(src.size()); - buf.residual.resize(src.size()); + stage.copy_buf.resize(src.size()); + stage.compressed_recv_buf.resize(src.size()); + stage.compressed_send_buf.resize(src.size()); + stage.residual.resize(src.size()); for (size_t i = 0; i < src.size(); ++i) { - buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(), - false, buf.merged.dtype()); - buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(), - false, buf.merged.dtype()); - buf.residual[i] = 0; + stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(), + false, stage.merged.dtype()); + stage.residual[i] = NDArray(stage.merged.shape(), src[i].ctx(), false, + stage.merged.dtype()); + stage.residual[i] = 0; + int64_t small_size = + gc_->GetCompressedSize(stage.merged.shape().Size()); + stage.compressed_recv_buf[i] = + NDArray(TShape{small_size}, stage.merged.ctx(), false, + stage.merged.dtype()); + stage.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(), + false, stage.merged.dtype()); + } + } else if (!buf.merged.is_none()) { + if (buf.copy_buf.empty() && stage.copy_buf.empty()) { + buf.copy_buf.resize(group1.size() + 1); + buf.compressed_recv_buf.resize(group1.size() + 1); + buf.compressed_send_buf.resize(group1.size() + 1); + buf.residual.resize(group1.size()); + stage.copy_buf.resize(group2.size()); + stage.compressed_recv_buf.resize(group2.size()); + stage.compressed_send_buf.resize(group2.size()); + stage.residual.resize(group2.size()); + for (size_t i = 0, j = 0, k = 0; i < src.size(); ++i) { + int id = src[i].ctx().dev_id; + if (id < NVLINK_SUPPORT) { + buf.copy_buf[j] = NDArray(buf.merged.shape(), buf.merged.ctx(), + false, buf.merged.dtype()); + buf.residual[j] = NDArray(buf.merged.shape(), src[i].ctx(), false, + buf.merged.dtype()); + buf.residual[j] = 0; + int64_t small_size = + gc_->GetCompressedSize(buf.merged.shape().Size()); + buf.compressed_recv_buf[j] = + NDArray(TShape{small_size}, buf.merged.ctx(), false, + buf.merged.dtype()); + buf.compressed_send_buf[j] = NDArray( + TShape{small_size}, src[i].ctx(), false, buf.merged.dtype()); + j++; + } else { + stage.copy_buf[k] = + NDArray(stage.merged.shape(), stage.merged.ctx(), false, + stage.merged.dtype()); + stage.residual[k] = NDArray(stage.merged.shape(), src[i].ctx(), + false, stage.merged.dtype()); + stage.residual[k] = 0; + int64_t small_size = + gc_->GetCompressedSize(stage.merged.shape().Size()); + stage.compressed_recv_buf[k] = + NDArray(TShape{small_size}, stage.merged.ctx(), false, + stage.merged.dtype()); + stage.compressed_send_buf[k] = NDArray( + TShape{small_size}, src[i].ctx(), false, stage.merged.dtype()); + k++; + } + } + buf.copy_buf[group1.size()] = NDArray( + buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype()); int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size()); - buf.compressed_recv_buf[i] = NDArray(TShape{small_size}, buf.merged.ctx(), - false, buf.merged.dtype()); - buf.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(), - false, buf.merged.dtype()); + buf.compressed_recv_buf[group1.size()] = NDArray( + TShape{small_size}, buf.merged.ctx(), false, buf.merged.dtype()); + buf.compressed_send_buf[group1.size()] = NDArray( + TShape{small_size}, stage.merged.ctx(), false, buf.merged.dtype()); } 
} + std::vector reduce_s(stage.copy_buf.size()); + std::vector reduce(buf.copy_buf.size()); + + for (size_t i = 0, j = 0, k = 0; i < src.size(); ++i) { + int id = src[i].ctx().dev_id; + if (id >= NVLINK_SUPPORT || buf.merged.is_none()) { + // compress before copy + // this is done even if the data is on same context as copy_buf because + // we don't want the training to be biased towards data on this GPU + gc_->Quantize(src[i], &(stage.compressed_send_buf[j]), + &(stage.residual[j]), priority); + + if (stage.compressed_send_buf[j].ctx() != + stage.compressed_recv_buf[j].ctx()) { + CopyFromTo(stage.compressed_send_buf[j], + &(stage.compressed_recv_buf[j]), priority); + } else { + // avoid memory copy when they are on same context + stage.compressed_recv_buf[j] = stage.compressed_send_buf[j]; + } - for (size_t i = 0; i < src.size(); ++i) { - // compress before copy - // this is done even if the data is on same context as copy_buf because - // we don't want the training to be biased towards data on this GPU - gc_->Quantize(src[i], &(buf.compressed_send_buf[i]), &(buf.residual[i]), priority); - - if (buf.compressed_send_buf[i].ctx() != buf.compressed_recv_buf[i].ctx()) { - CopyFromTo(buf.compressed_send_buf[i], &(buf.compressed_recv_buf[i]), priority); + gc_->Dequantize(stage.compressed_recv_buf[j], &(stage.copy_buf[j]), + priority); + reduce_s[j] = stage.copy_buf[j]; + j++; } else { - // avoid memory copy when they are on same context - buf.compressed_recv_buf[i] = buf.compressed_send_buf[i]; - } + gc_->Quantize(src[i], &(buf.compressed_send_buf[k]), &(buf.residual[k]), + priority); + + if (buf.compressed_send_buf[k].ctx() != + buf.compressed_recv_buf[k].ctx()) { + CopyFromTo(buf.compressed_send_buf[k], &(buf.compressed_recv_buf[k]), + priority); + } else { + // avoid memory copy when they are on same context + buf.compressed_recv_buf[k] = buf.compressed_send_buf[k]; + } - gc_->Dequantize(buf.compressed_recv_buf[i], &(buf.copy_buf[i]), priority); - reduce[i] = buf.copy_buf[i]; + gc_->Dequantize(buf.compressed_recv_buf[k], &(buf.copy_buf[k]), + priority); + reduce[k] = buf.copy_buf[k]; + k++; + } + } + ElementwiseSum(reduce_s, &stage.merged); + if (buf.merged.is_none()) { + return stage.merged; + } else { + gc_->Quantize(stage.merged, &buf.compressed_send_buf[group1.size()], + &(buf.residual[buf.residual.size() - 1]), priority); + CopyFromTo(buf.compressed_send_buf[group1.size()], + &(buf.compressed_recv_buf[group1.size()]), priority); + gc_->Dequantize(buf.compressed_recv_buf[group1.size()], + &(buf.copy_buf[group1.size()]), priority); + reduce[reduce.size() - 1] = buf.copy_buf[group1.size()]; + ElementwiseSum(reduce, &buf.merged); } - ElementwiseSum(reduce, &buf.merged); + return buf.merged; } - void Broadcast(int key, const NDArray& src, - const std::vector dst, int priority) override { + void Broadcast(int key, const NDArray& src, const std::vector dst, + int priority) override { if (!inited_) { // copy to a random device first int dev_id = key % dst.size(); @@ -622,20 +777,24 @@ class CommDevice : public Comm { } } } else { - auto& buf = merge_buf_[key]; - CopyFromTo(src, &buf.merged, priority); + BufferEntry& buf = merge_buf_[key]; + BufferEntry& stage = stage_buf_[key]; + if (!buf.merged.is_none()) CopyFromTo(src, &buf.merged, priority); + CopyFromTo(src, &stage.merged, priority); for (auto d : dst) { - CopyFromTo(buf.merged, d, priority); + if (d->ctx().dev_id >= NVLINK_SUPPORT || buf.merged.is_none()) + CopyFromTo(stage.merged, d, priority); + else + CopyFromTo(buf.merged, d, 
priority); } } } void BroadcastRowSparse(int key, const NDArray& src, const std::vector>& dst, - const bool use_copy, - const int priority) override { + const bool use_copy, const int priority) override { CHECK_EQ(src.storage_type(), kRowSparseStorage) - << "BroadcastRowSparse expects row-sparse src NDArray"; + << "BroadcastRowSparse expects row-sparse src NDArray"; for (size_t i = 0; i < dst.size(); ++i) { NDArray* out = dst[i].first; @@ -644,38 +803,44 @@ class CommDevice : public Comm { CopyFromTo(src, out, priority); } else { CHECK_EQ(out->storage_type(), kRowSparseStorage) - << "BroadcastRowSparse expects row_sparse dst NDArray"; + << "BroadcastRowSparse expects row_sparse dst NDArray"; const bool is_diff_ctx = out->ctx() != src.ctx(); - NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(), - src.ctx(), true, out->dtype(), out->aux_types()) : *out; + NDArray out_gpu = + is_diff_ctx ? NDArray(kRowSparseStorage, out->shape(), src.ctx(), + true, out->dtype(), out->aux_types()) + : *out; CHECK_EQ(row_id.ctx(), src.ctx()) - << "row_id and src are expected to be on the same context"; - - Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) { - NDArray temp = out_gpu; - const TBlob& indices = row_id.data(); - switch (temp.ctx().dev_mask()) { - case cpu::kDevMask: { - mxnet::common::SparseRetainOpForwardRspWrapper(rctx.get_stream(), - src, indices, kWriteTo, &temp); - break; - } + << "row_id and src are expected to be on the same context"; + + Engine::Get()->PushAsync( + [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { + NDArray temp = out_gpu; + const TBlob& indices = row_id.data(); + switch (temp.ctx().dev_mask()) { + case cpu::kDevMask: { + mxnet::common::SparseRetainOpForwardRspWrapper( + rctx.get_stream(), src, indices, kWriteTo, &temp); + break; + } #if MXNET_USE_CUDA - case gpu::kDevMask: { - mxnet::common::SparseRetainOpForwardRspWrapper(rctx.get_stream(), - src, indices, kWriteTo, &temp); - // wait for GPU operations to complete - rctx.get_stream()->Wait(); - break; - } + case gpu::kDevMask: { + mxnet::common::SparseRetainOpForwardRspWrapper( + rctx.get_stream(), src, indices, kWriteTo, &temp); + // wait for GPU operations to complete + rctx.get_stream()->Wait(); + break; + } #endif - default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; - } - on_complete(); - }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()}, - FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain")); + default: + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; + } + on_complete(); + }, + out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()}, + FnProperty::kNormal, priority, + PROFILER_MESSAGE("KVStoreSparseRetain")); if (is_diff_ctx) { CopyFromTo(out_gpu, out, priority); } @@ -694,7 +859,7 @@ class CommDevice : public Comm { } int n = static_cast(gpus.size()); int enabled = 0; - std::vector p2p(n*n); + std::vector p2p(n * n); for (int i = 0; i < n; ++i) { cudaSetDevice(gpus[i]); for (int j = 0; j < n; j++) { @@ -704,21 +869,21 @@ class CommDevice : public Comm { cudaError_t e = cudaDeviceEnablePeerAccess(gpus[j], 0); if (e == cudaSuccess || e == cudaErrorPeerAccessAlreadyEnabled) { ++enabled; - p2p[i*n+j] = 1; + p2p[i * n + j] = 1; } } } } - if (enabled != n*(n-1)) { + if (enabled != n * (n - 1)) { // print warning info if not fully enabled - LOG(WARNING) << "only " << enabled << " out of " - << n*(n-1) << " GPU pairs are enabled direct access. 
" + LOG(WARNING) << "only " << enabled << " out of " << n * (n - 1) + << " GPU pairs are enabled direct access. " << "It may affect the performance. " << "You can set MXNET_ENABLE_GPU_P2P=0 to turn it off"; std::string access(n, '.'); for (int i = 0; i < n; ++i) { for (int j = 0; j < n; ++j) { - access[j] = p2p[i*n+j] ? 'v' : '.'; + access[j] = p2p[i * n + j] ? 'v' : '.'; } LOG(WARNING) << access; } @@ -729,40 +894,109 @@ class CommDevice : public Comm { using KeyAttrs = std::tuple; // try to allocate buff on device evenly void InitMergeBuffer(const std::vector& devs) { - std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), []( - const KeyAttrs& a, const KeyAttrs& b) { - return std::get<1>(a).Size() > std::get<1>(b).Size(); - }); - - std::unordered_map> ctx_info; - for (auto d : devs) { - ctx_info[d.dev_id] = std::make_pair(d, 0); + std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), + [](const KeyAttrs& a, const KeyAttrs& b) { + return std::get<1>(a).Size() > std::get<1>(b).Size(); + }); + + for (auto& d : devs) { + if (d.dev_id < NVLINK_SUPPORT) + group1.push_back(d); + else + group2.push_back(d); } - for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) { - const int key = std::get<0>(sorted_key_attrs_[i]); - const TShape& shape = std::get<1>(sorted_key_attrs_[i]); - const int type = std::get<2>(sorted_key_attrs_[i]); - const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]); - auto& buf = merge_buf_[key]; - Context ctx; - size_t min_size = std::numeric_limits::max(); - for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) { - size_t size = it->second.second; - if (size <= min_size) { - ctx = it->second.first; - min_size = size; + if (group1.empty() || group2.empty()) { + // all gpus are all connected by NVLinks: use all-to-all + std::unordered_map> ctx_info; + for (auto d : devs) { + ctx_info[d.dev_id] = std::make_pair(d, 0); + } + for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) { + const int key = std::get<0>(sorted_key_attrs_[i]); + const TShape shape = std::get<1>(sorted_key_attrs_[i]); + const int type = std::get<2>(sorted_key_attrs_[i]); + const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]); + BufferEntry& stage = stage_buf_[key]; + Context ctx; + size_t min_size = std::numeric_limits::max(); + for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) { + size_t size = it->second.second; + if (size <= min_size) { + ctx = it->second.first; + min_size = size; + } + } + if (stype == kDefaultStorage) { + stage.merged = NDArray(shape, ctx, false, type); + } else { + stage.merged = NDArray(stype, shape, ctx, true, type); } + ctx_info[ctx.dev_id].second += shape.Size(); } - if (stype == kDefaultStorage) { - buf.merged = NDArray(shape, ctx, false, type); - } else { - buf.merged = NDArray(stype, shape, ctx, true, type); + } else { + // QPI connections are included: use spanning tree + size_t gpu0, gpu1; + // gpu0 and gpu1 hold the gpu indexes connected by nvlink between group1 + // and group2 groups accordingly + for (gpu0 = 0, gpu1 = 0; gpu0 < group1.size() && gpu1 < group2.size();) { + if (group2[gpu1].dev_id - group1[gpu0].dev_id == NVLINK_SUPPORT) + break; + else if (group2[gpu1].dev_id - group1[gpu0].dev_id > NVLINK_SUPPORT) + gpu0++; + else + gpu1++; + } + if (gpu0 == group1.size() || gpu1 == group2.size()) gpu0 = gpu1 = 0; + for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) { + const int key = std::get<0>(sorted_key_attrs_[i]); + const TShape shape = std::get<1>(sorted_key_attrs_[i]); + const int type = 
std::get<2>(sorted_key_attrs_[i]); + const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]); + BufferEntry& buf = merge_buf_[key]; + BufferEntry& stage = stage_buf_[key]; + if (stype == kDefaultStorage) { + buf.merged = NDArray(shape, group1[gpu0], false, type); + if (buf.copy_buf.empty()) { + buf.copy_buf.resize(group1.size()); + for (size_t i = 0; i < group1.size(); ++i) + buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(), + false, buf.merged.dtype()); + } + + stage.merged = NDArray(shape, group2[gpu1], false, type); + if (stage.copy_buf.empty()) { + stage.copy_buf.resize(group2.size() - 1); + for (size_t i = 0; i < group2.size() - 1; ++i) + stage.copy_buf[i] = + NDArray(stage.merged.shape(), stage.merged.ctx(), false, + stage.merged.dtype()); + } + } else { + buf.merged = NDArray(stype, shape, group1[gpu0], true, type); + if (buf.copy_buf.empty()) { + buf.copy_buf.resize(group1.size() + 1); + for (size_t i = 0; i < group1.size() + 1; ++i) + buf.copy_buf[i] = + NDArray(stype, buf.merged.shape(), buf.merged.ctx(), true, + buf.merged.dtype()); + } + + stage.merged = NDArray(stype, shape, group2[gpu1], true, type); + if (stage.copy_buf.empty()) { + stage.copy_buf.resize(group2.size()); + for (size_t i = 0; i < group2.size(); ++i) + stage.copy_buf[i] = + NDArray(stype, stage.merged.shape(), stage.merged.ctx(), true, + stage.merged.dtype()); + } + } } - ctx_info[ctx.dev_id].second += shape.Size(); } inited_ = true; } + /// \brief the NVLinked connected gpu groups + std::vector group1, group2; std::vector sorted_key_attrs_; /// \brief temporal space for pushing and pulling struct BufferEntry { @@ -778,6 +1012,8 @@ class CommDevice : public Comm { std::vector compressed_recv_buf; }; std::unordered_map merge_buf_; + /// \brief the small buffer for partially merged data + std::unordered_map stage_buf_; bool inited_; };
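
A few standalone C++ sketches follow. They are illustrative notes on the patch, not part of the diff, and every helper name in them (ChunkedReduceSum, PickBridgePair, TwoStageReduce, etc.) is hypothetical. First, the CPU path: ReduceSumCPUImpl splits each array into fixed-size chunks and sums the chunks in parallel with OpenMP. A simplified version of that chunking, without the 4-way unrolling done by ReduceSumCPU:

#include <algorithm>
#include <cstddef>
#include <vector>

// Sum every array in `srcs` into `dst` (all of length `total`), working in
// chunks of `step` elements so OpenMP threads touch disjoint, cache-friendly
// ranges -- the same chunking idea as ReduceSumCPUImpl above.
// Compile with -fopenmp (falls back to a serial loop otherwise).
void ChunkedReduceSum(float* dst, const std::vector<const float*>& srcs,
                      size_t total, size_t step = 4 << 10) {
  const long ntask = static_cast<long>((total + step - 1) / step);
#pragma omp parallel for schedule(static)
  for (long j = 0; j < ntask; ++j) {
    const size_t begin = std::min(static_cast<size_t>(j) * step, total);
    const size_t end = std::min(begin + step, total);
    for (const float* src : srcs) {
      for (size_t i = begin; i < end; ++i) dst[i] += src[i];
    }
  }
}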
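
BroadcastRowSparse ships only the rows listed in row_id, either through the sparse_retain op or CopyRetainedRowsToGPU. The host-side idea, reduced to a plain row-major buffer (RetainRows is a hypothetical helper; row_ids are assumed valid, sorted and unique as the patch requires):

#include <algorithm>
#include <cstddef>
#include <vector>

// Copy only the requested rows of a dense row-major matrix into a compact
// destination buffer -- the host-side analogue of retaining rows of a
// row_sparse NDArray. `row_length` is the number of elements per row.
std::vector<float> RetainRows(const std::vector<float>& src, size_t row_length,
                              const std::vector<size_t>& row_ids) {
  std::vector<float> dst(row_ids.size() * row_length);
  for (size_t i = 0; i < row_ids.size(); ++i) {
    const float* from = src.data() + row_ids[i] * row_length;
    std::copy(from, from + row_length, dst.data() + i * row_length);
  }
  return dst;
}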
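
EnableP2P (substantively unchanged by this patch) probes every ordered GPU pair with cudaDeviceCanAccessPeer / cudaDeviceEnablePeerAccess and warns when some pairs lack direct access. A trimmed standalone sketch of that probe, with error handling omitted:

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Try to enable peer access between every ordered pair of GPUs and report
// how many pairs succeeded, mirroring the warning logic in EnableP2P().
void ProbePeerAccess(const std::vector<int>& gpus) {
  const int n = static_cast<int>(gpus.size());
  int enabled = 0;
  for (int i = 0; i < n; ++i) {
    cudaSetDevice(gpus[i]);
    for (int j = 0; j < n; ++j) {
      if (i == j) continue;
      int access = 0;
      cudaDeviceCanAccessPeer(&access, gpus[i], gpus[j]);
      if (!access) continue;
      cudaError_t e = cudaDeviceEnablePeerAccess(gpus[j], 0);
      if (e == cudaSuccess || e == cudaErrorPeerAccessAlreadyEnabled) ++enabled;
    }
  }
  if (enabled != n * (n - 1)) {
    std::printf("only %d of %d GPU pairs have direct (P2P) access\n",
                enabled, n * (n - 1));
  }
}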
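
Finally, the shape of the new CommDevice::Reduce: gradients from the second group (dev_id >= NVLINK_SUPPORT) are summed into a staging buffer first, and that single partial sum is then folded into the merge buffer of the first group, so only one tensor per key crosses the inter-group (QPI) link. A minimal host-side sketch of the same two-stage pattern, with a hypothetical Grad struct standing in for NDArray:

#include <cassert>
#include <vector>

constexpr int kNvlinkSupport = 4;

struct Grad {
  int dev_id;                // owning GPU id
  std::vector<float> data;   // gradient values (assumed non-empty)
};

// Element-wise accumulate x into *acc (stand-in for ElementwiseSum).
static void SumInto(std::vector<float>* acc, const std::vector<float>& x) {
  assert(acc->size() == x.size());
  for (size_t i = 0; i < x.size(); ++i) (*acc)[i] += x[i];
}

// Two-stage reduce: devices with dev_id >= kNvlinkSupport are reduced into a
// staging buffer first; the partial sum is then added to the main group once.
std::vector<float> TwoStageReduce(const std::vector<Grad>& src) {
  std::vector<float> stage, merged;
  for (const Grad& g : src) {
    std::vector<float>* dst = (g.dev_id >= kNvlinkSupport) ? &stage : &merged;
    if (dst->empty()) *dst = g.data; else SumInto(dst, g.data);
  }
  if (merged.empty()) return stage;             // only the second group present
  if (!stage.empty()) SumInto(&merged, stage);  // one transfer across the link
  return merged;
}

The real patch additionally keeps per-key pre-allocated copy_buf / stage_buf_ entries on the bridge devices instead of building the sums on the fly, and the compressed path applies gc_->Quantize / gc_->Dequantize around each copy in the same two-stage order.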