This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-331] NVLink communication pattern updated #8915

Closed
Laurawly wants to merge 8 commits into apache:master from Laurawly:master

Conversation


@Laurawly Laurawly commented Dec 1, 2017

Description

Optimized the kvstore communication pattern to make full use of NVLink.

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • For user-facing API changes, API doc string has been updated. For new C++ functions in header files, their functionalities and arguments are well-documented.
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Device reduce
  • Device broadcast

Comments

  • The changes make kvstore use more NVLink/PCIe links and avoid QPI when both are present; the handled configurations include 4 and 8 GPUs.
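The grouped communication idea described above can be sketched as follows. This is an illustrative stand-in, not the PR's actual code: `SplitIntoGroups` and the fixed threshold of 4 are hypothetical names mirroring the `dev_id < 4` check in src/kvstore/comm.h, shown only to make the two-group reduce concrete.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical sketch: split GPU device ids into two NVLink-connected
// groups, as the PR does for 4- and 8-GPU machines. Devices 0..3 land
// in group 1 and devices 4..7 in group 2, mirroring the `dev_id < 4`
// check in the PR.
std::pair<std::vector<int>, std::vector<int>> SplitIntoGroups(int num_gpus) {
  std::vector<int> g1, g2;
  for (int d = 0; d < num_gpus; ++d) {
    (d < 4 ? g1 : g2).push_back(d);
  }
  return {g1, g2};
}
```

Reduce then proceeds in two stages: each group first reduces internally over NVLink/PCIe, and only the two partial results cross between groups, so the bulk of the traffic never touches QPI.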

Contributor

@weixingzhang weixingzhang left a comment

nit: indentation

Comment thread src/kvstore/comm.h Outdated
CopyFromTo(src[0], &buf.merged, priority);
return buf.merged;
CopyFromTo(src[0], &buf.merged, priority);
return buf.merged;
Contributor

The indentation needs to be corrected.

Comment thread src/kvstore/comm.h Outdated
ElementwiseSum(reduce, &buf.merged);
CopyFromTo(stage.merged, &(buf.copy_buf[buf.copy_buf.size()-1]), priority);
reduce[reduce.size()-1] = buf.copy_buf[buf.copy_buf.size()-1];
ElementwiseSum(reduce, &buf.merged);
Member

Missing priority? Should this line be ElementwiseSum(reduce, &buf.merged, priority);?

Comment thread src/kvstore/comm.h Outdated
for (size_t i = 0; i < dst.size(); ++i) {
if (i != static_cast<size_t>(dev_id)) {
CopyFromTo(*dst[dev_id], dst[i], priority);
CopyFromTo(*dst[dev_id], (dst[i]), priority);
Member

nit: why is there a bracket around (dst[i])?

Comment thread src/kvstore/comm.h Outdated
// copy to a random device first
int dev_id = key % dst.size();
CopyFromTo(src, dst[dev_id], priority);
CopyFromTo(src, (dst[dev_id]), priority);
Member

nit: why is there a bracket around dst[dev_id]?

Comment thread src/kvstore/comm.h Outdated

std::vector<Context> g1, g2;
for (auto& d : devs) {
if (d.dev_id < 4) g1.push_back(d);
Contributor

Can we decide this by querying CUDA instead of using magic numbers?

@piiswrong
Contributor

Looks like this is not turned off when not using nvlink?

Comment thread src/kvstore/comm.h Outdated
buf.copy_buf[i] = NDArray(
buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
}
if (buf.merged.is_none()&& stage.copy_buf.empty()) {
Member

nit: space before &&

Member

I don't think buf.merged.is_none() will ever be true, since InitBuffersAndComm has initialized buf.merged?

Comment thread src/kvstore/comm.h
std::vector<NDArray> compressed_recv_buf;
};
std::unordered_map<int, BufferEntry> merge_buf_;
std::unordered_map<int, BufferEntry> stage_buf_;
Member

Please add a brief description of what this is for.

Comment thread src/kvstore/comm.h
stage.merged = NDArray(s, ctx, false, type);
ctx_info[ctx.dev_id].second += s.Size();
}
} else {
Member

What's the impact of this update on older devices/architectures?

Author

It will avoid using QPI.

@eric-haibin-lin
Member

@rahul003 please help review

@Laurawly
Author

Laurawly commented Dec 5, 2017

@piiswrong When not using NVLink, this method uses more PCIe bandwidth instead of QPI, which also accelerates the original communication.

@eric-haibin-lin eric-haibin-lin self-assigned this Dec 6, 2017
@rahul003
Member

rahul003 commented Dec 8, 2017

@Laurawly Can we also make use of this feature for the ReduceCompressed function?

Comment thread src/kvstore/comm.h Outdated
return std::get<1>(a).Size() > std::get<1>(b).Size();
});

std::vector<Context> g1, g2;
Member

Are there more readable variable names for g1 and g2 that would explain their purpose?

Author

So g represents "group" here; I separate the GPU cards into two communication groups.

Author

I've added some description for g1 and g2, and moved them to class members.

@Laurawly
Author

Laurawly commented Dec 13, 2017

@rahul003 Yeah, I'll update the ReduceCompressed function as well. Thanks for the reminder.

@Laurawly
Author

@rahul003 Could you review my updates in ReduceCompressed? Thanks in advance!

@szha
Member

szha commented Dec 22, 2017

@rahul003 ping

@Laurawly Laurawly force-pushed the master branch 4 times, most recently from 067b612 to e10997d Compare January 5, 2018 19:28
@eric-haibin-lin
Member

Any idea why test_rsp_pull failed?

Comment thread src/kvstore/comm.h
pinned_ctx_ = Context::CPUPinned(0);
}
virtual ~Comm() { }
Comm() { pinned_ctx_ = Context::CPUPinned(0); }
Contributor

Why don't we initialize this in the constructor initializer list? It's more efficient.
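The reviewer's suggestion can be sketched as follows. `PinnedHolder` and the string context are illustrative stand-ins (the real code uses mxnet::Context::CPUPinned(0)); the point is that an initializer list constructs the member directly instead of default-constructing it and then assigning.

```cpp
#include <cassert>
#include <string>

// Illustrative stand-in type: the real member is an mxnet::Context.
struct PinnedHolder {
  // Assignment in the body would default-construct `ctx` first, then
  // assign to it:
  //   PinnedHolder() { ctx = "cpu_pinned(0)"; }
  // The initializer list constructs it in place, skipping the extra step:
  PinnedHolder() : ctx("cpu_pinned(0)") {}
  std::string ctx;
};
```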

Comment thread src/kvstore/comm.h
int key, const NDArray& src,
const std::vector<NDArray*> dst, int priority) = 0;
virtual void Broadcast(int key, const NDArray& src,
const std::vector<NDArray*> dst, int priority) = 0;
Contributor

The vector is passed by value; shouldn't it be passed by reference?
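A minimal sketch of the pass-by-reference point, with a hypothetical `CountDst` helper standing in for Broadcast's parameter: taking the vector by const reference avoids copying the whole vector of pointers on every call, while leaving the call site unchanged.

```cpp
#include <cassert>
#include <vector>

// Taking `dst` by const reference avoids copying the vector (and its
// pointer elements) on each call; by-value would copy it every time.
int CountDst(const std::vector<int*>& dst) {
  return static_cast<int>(dst.size());
}
```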

Comment thread src/kvstore/comm.h Outdated
auto& buf = merge_buf_[key];
std::vector<NDArray> reduce(src.size());
if (buf.copy_buf.empty()) {
auto& stage = stage_buf_[key];
Contributor

I would write out the full type here for readability.

Comment thread src/kvstore/comm.h

void Broadcast(int key, const NDArray& src,
const std::vector<NDArray*> dst, int priority) override {
void Broadcast(int key, const NDArray& src, const std::vector<NDArray*> dst,
Contributor

Shouldn't this be passed by reference?

@larroy
Contributor

larroy commented Jan 11, 2018

Why is this file such a big header with no implementation file?

@eric-haibin-lin
Member

@rahul003 can you review the changes made for grad compression? Thanks!

Member

@rahul003 rahul003 left a comment

Thanks for modifying reduceCompressed too.
Could you please add a few comments to reduce or reduceCompressed explaining the flow of data? That would make this easier for others to maintain or develop further.

Comment thread src/kvstore/comm.h Outdated
}

/// \brief the NVLinked connected gpu groups
std::vector<Context> g1, g2;
Member

It might be more readable to expand the names of these variables?

Comment thread src/kvstore/comm.h
}
} else {
// QPI connections are included: use spanning tree
size_t gpu0, gpu1;
Member

Could you add some comments on what the computation below is doing? What do gpu0 and gpu1 hold?

Comment thread src/kvstore/comm.h Outdated
int id = src[i].ctx().dev_id;
if ((!buf.merged.is_none() && id == stage.merged.ctx().dev_id) ||
(buf.merged.is_none() && i == 0)) {
CopyFromTo(src[i], &(stage.merged), priority);
Member

Why do we have to copy src[i] onto the same ctx? Can we use src[i] directly?

Comment thread src/kvstore/comm.h Outdated
buf.copy_buf.resize(g1.size() + 1);
buf.compressed_recv_buf.resize(g1.size() + 1);
buf.compressed_send_buf.resize(g1.size() + 1);
buf.residual.resize(g1.size() + 1);
Member

We are declaring g1.size()+1 as the size of the residual array, but residuals are not sent to other GPUs, so we don't need to allocate the extra residual array.

Author

That extra array is for copying the reduced value back from the stage buffer.

Member

But the residual array only remains on the original GPU. It is never sent anywhere; it is only updated in place. Or are you just declaring an extra array for the residual so that you can index this array like the other arrays (copy_buf or compressed_recv_buf)?

Either way, we can avoid creating an extra residual array, right? That would be a significant memory saving (equal to the parameters of the model).

Author

Oh, I see what you mean. Yeah, that's right. I'll correct it accordingly.

Author

Tested tests/nightly/test_kvstore.py and it passes.

@rahul003
Member

rahul003 commented Jan 25, 2018

Could you please do these three things:

  1. Add some comments on the flow of data in the reduce function? The flow is not easy to follow from the code alone.
  2. Ensure that tests/nightly/test_kvstore.py passes.
  3. Fix the extra residual array issue.

@Laurawly
Author

@eric-haibin-lin Could you check whether test_rsp_pull passes now?

@rahul003
Member

rahul003 commented Jan 26, 2018

No, it is still failing.

test_kvstore_gpu.test_rsp_push_pull ... terminate called after throwing an instance of 'dmlc::Error'

  what():  [23:44:37] src/engine/./threaded_engine.h:359: [23:44:37] src/ndarray/ndarray_function.cc:181: ElementwiseSum<cpu> has not been implemented for storage_type = << 0


Stack trace returned 10 entries:

[bt] (0) /workspace/python/mxnet/../../lib/libmxnet.so(dmlc::StackTrace[abi:cxx11]()+0x5a) [0x7f9f39b7334a]

[bt] (1) /workspace/python/mxnet/../../lib/libmxnet.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x28) [0x7f9f39b73ee8]

[bt] (2) /workspace/python/mxnet/../../lib/libmxnet.so(void mxnet::ndarray::ElementwiseSum<mshadow::cpu>(mshadow::Stream<mshadow::cpu>*, mxnet::Resource const&, std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> > const&, mxnet::NDArray*)+0x6a) [0x7f9f3c16f8aa]

[bt] (3) /workspace/python/mxnet/../../lib/libmxnet.so(+0x2f0ad9d) [0x7f9f3c1a5d9d]

[bt] (4) /workspace/python/mxnet/../../lib/libmxnet.so(+0x330848b) [0x7f9f3c5a348b]

[bt] (5) /workspace/python/mxnet/../../lib/libmxnet.so(mxnet::engine::ThreadedEngine::ExecuteOprBlock(mxnet::RunContext, mxnet::engine::OprBlock*)+0x100) [0x7f9f3c5afb50]

[bt] (6) /workspace/python/mxnet/../../lib/libmxnet.so(std::_Function_handler<void (std::shared_ptr<mxnet::engine::ThreadPool::SimpleEvent>), mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#1}::operator()() const::{lambda(std::shared_ptr<mxnet::engine::ThreadPool::SimpleEvent>)#1}>::_M_invoke(std::_Any_data const&, std::shared_ptr<mxnet::engine::ThreadPool::SimpleEvent>&&)+0xe2) [0x7f9f3c5b7b42]

[bt] (7) /workspace/python/mxnet/../../lib/libmxnet.so(std::thread::_Impl<std::_Bind_simple<std::function<void (std::shared_ptr<mxnet::engine::ThreadPool::SimpleEvent>)> (std::shared_ptr<mxnet::engine::ThreadPool::SimpleEvent>)> >::_M_run()+0x4a) [0x7f9f3c5b210a]

[bt] (8) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7f9f45601c80]

[bt] (9) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f9f4cf136ba]



A fatal error occurred in asynchronous engine operation. If you do not know what caused this error, you can try set environment variable MXNET_ENGINE_TYPE to NaiveEngine and run with debugger (i.e. gdb). This will force all operations to be synchronous and backtrace will give you the series of calls that lead to this error. Remember to set MXNET_ENGINE_TYPE back to empty after debugging.


Stack trace returned 8 entries:

[bt] (0) /workspace/python/mxnet/../../lib/libmxnet.so(dmlc::StackTrace[abi:cxx11]()+0x5a) [0x7f9f39b7334a]

[bt] (1) /workspace/python/mxnet/../../lib/libmxnet.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x28) [0x7f9f39b73ee8]

[bt] (2) /workspace/python/mxnet/../../lib/libmxnet.so(mxnet::engine::ThreadedEngine::ExecuteOprBlock(mxnet::RunContext, mxnet::engine::OprBlock*)+0x39a) [0x7f9f3c5afdea]

[bt] (3) /workspace/python/mxnet/../../lib/libmxnet.so(std::_Function_handler<void (std::shared_ptr<mxnet::engine::ThreadPool::SimpleEvent>), mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#1}::operator()() const::{lambda(std::shared_ptr<mxnet::engine::ThreadPool::SimpleEvent>)#1}>::_M_invoke(std::_Any_data const&, std::shared_ptr<mxnet::engine::ThreadPool::SimpleEvent>&&)+0xe2) [0x7f9f3c5b7b42]

[bt] (4) /workspace/python/mxnet/../../lib/libmxnet.so(std::thread::_Impl<std::_Bind_simple<std::function<void (std::shared_ptr<mxnet::engine::ThreadPool::SimpleEvent>)> (std::shared_ptr<mxnet::engine::ThreadPool::SimpleEvent>)> >::_M_run()+0x4a) [0x7f9f3c5b210a]

[bt] (5) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7f9f45601c80]

[bt] (6) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f9f4cf136ba]

[bt] (7) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f9f4cc493dd]


MKL Build:20171227

@Laurawly Laurawly force-pushed the master branch 2 times, most recently from 7332be1 to d3aeed5 Compare January 29, 2018 19:47
@Laurawly
Author

Laurawly commented Jan 29, 2018

@rahul003 should be solved by commit 683653e

@Laurawly
Author

@piiswrong ping.

Comment thread src/kvstore/comm.h
int mask = src.ctx().dev_mask();
if (mask == Context::kCPU) {
for (auto d : dst) CopyFromTo(src, d, priority);
for (auto& d : dst) CopyFromTo(src, d, priority);
Member

Is broadcast not leveraging the new comm pattern? The same logic could be applied to fully utilize the bandwidth during the copy, right? Or is the plan to do that in the next PR?
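The grouped pattern the reviewer asks about could be applied to broadcast roughly as follows. `GroupedBroadcast` is a hypothetical sketch, not the PR's code: it copies once across the inter-group link and then fans out within each group, returning the copy counts so the link usage is visible. The `src_dev < 4` split mirrors the PR's grouping assumption.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical two-stage broadcast over two GPU groups g1 (ids 0..3)
// and g2 (ids 4..7). Returns {inter-group copies, intra-group copies}.
std::pair<int, int> GroupedBroadcast(int src_dev, const std::vector<int>& g1,
                                     const std::vector<int>& g2) {
  int inter = 0, intra = 0;
  const std::vector<int>& own = (src_dev < 4) ? g1 : g2;
  const std::vector<int>& other = (src_dev < 4) ? g2 : g1;
  if (!other.empty()) ++inter;                 // one copy to the other group's root
  intra += static_cast<int>(own.size()) - 1;   // fan out within the source group
  if (!other.empty()) {
    intra += static_cast<int>(other.size()) - 1;  // fan out within the other group
  }
  return {inter, intra};
}
```

Only a single transfer crosses the slow inter-group link; everything else stays on NVLink/PCIe within a group.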

Comment thread src/kvstore/comm.h
rctx.get_stream<gpu>()->Wait();
break;
}
case gpu::kDevMask: {
Contributor

Wrong indentation?

Contributor

@piiswrong piiswrong left a comment

Two general points:

  1. Too many irrelevant cosmetic changes.
  2. The magic number 4 appears many times. Are you assuming there are 4 GPUs? This should be queried dynamically instead of being a constant.

Comment thread src/kvstore/comm.h
*/
#ifndef MXNET_KVSTORE_COMM_H_
#define MXNET_KVSTORE_COMM_H_
#define NVLINK_SUPPORT 4
Contributor

What's this? Can we avoid magic numbers?

Comment thread src/kvstore/comm.h
on_complete();
}, Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()},
FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
[=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
Contributor

This sort of cosmetic change is really distracting during code review. Try not to do it next time.

Comment thread src/kvstore/comm.h
reduce[0] = buf.merged;

if (buf.copy_buf.empty()) {
// TODO(mli) this results in large device memory usage for huge ndarray,
Contributor

Is this TODO handled by this PR?

Comment thread src/kvstore/comm.h
inline static void ReduceSumCPU(
const std::vector<DType*> &dptr, size_t offset, index_t size) {
template <typename DType>
inline static void ReduceSumCPU(const std::vector<DType*>& dptr,
Contributor

These changes are really annoying.

Comment thread src/kvstore/comm.h
reduce_s.resize(stage.copy_buf.size());
for (size_t i = 0, j = 0; i < src.size(); ++i) {
int id = src[i].ctx().dev_id;
if (id >= 4 || buf.merged.is_none()) {
Contributor

Why 4? Can we avoid magic numbers?

@CodingCat
Contributor

Hi, the community vote on associating code changes with JIRA has passed (https://lists.apache.org/thread.html/ab22cf0e35f1bce2c3bf3bec2bc5b85a9583a3fe7fd56ba1bbade55f@%3Cdev.mxnet.apache.org%3E).

We have updated the guidelines for contributors at https://cwiki.apache.org/confluence/display/MXNET/Development+Process. Please ensure that you have created a JIRA issue at https://issues.apache.org/jira/projects/MXNET/issues/ describing the work in this pull request, and include the JIRA title in your PR as [MXNET-xxxx] Your title, where MXNET-xxxx is the JIRA id.

Thanks!

@Jerryzcn
Contributor

When can we expect this to be merged?

@marcoabreu
Contributor

@Jerryzcn we're waiting for @Laurawly to address the review comments

@Laurawly Laurawly changed the title NVLink communication pattern updated [MXNET-331]NVLink communication pattern updated Apr 18, 2018
@Laurawly Laurawly changed the title [MXNET-331]NVLink communication pattern updated [MXNET-331] NVLink communication pattern updated Apr 18, 2018
@eric-haibin-lin
Member

Closing this for now due to inactivity.

@eric-haibin-lin
Member

Moved to #11357
