From e467e47f88376d2ed21ada2a1f753bee6e2d0f2c Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 14 Sep 2015 09:47:00 -0700 Subject: [PATCH 1/4] merge with min's pR --- src/symbol/graph_executor.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index 8e08952eb234..498ec4f942e0 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -205,6 +205,9 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { } } + for (const Resource& r : op_node.op_ctx.requested) { + exec.mutate_vars.push_back(static_cast(r.var)); + } // start setup exec function. for (const Resource& r : op_node.op_ctx.requested) { exec.mutate_vars.push_back(static_cast(r.var)); From 253cd7976f73e5bef5b4b8511cb69d8dfa447e42 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 14 Sep 2015 12:39:31 -0700 Subject: [PATCH 2/4] add threaded engine per device --- dmlc-core | 2 +- src/engine/engine.cc | 10 +- src/engine/engine_impl.h | 9 +- src/engine/stream_manager.h | 27 ++-- src/engine/thread_pool.h | 23 ++-- src/engine/threaded_engine_perdevice.cc | 170 ++++++++++++++++++++++++ src/engine/threaded_engine_pooled.cc | 12 +- 7 files changed, 215 insertions(+), 38 deletions(-) create mode 100644 src/engine/threaded_engine_perdevice.cc diff --git a/dmlc-core b/dmlc-core index 75f1950d386d..2e2d187efc43 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 75f1950d386d033b0b64919017515d27e698962a +Subproject commit 2e2d187efc43ee2df1d132c3690169575e830442 diff --git a/src/engine/engine.cc b/src/engine/engine.cc index 75bb58c6f56a..9698b490157a 100644 --- a/src/engine/engine.cc +++ b/src/engine/engine.cc @@ -15,12 +15,16 @@ inline Engine* CreateEngine() { const bool default_engine = (type == nullptr); if (type == nullptr) type = "ThreadedEngine"; std::string stype = type; + Engine *ret = nullptr; - if (stype == "ThreadedEngine") { - ret = CreateThreadedEngine(); - } else if (stype == "NaiveEngine") { + if (stype == "NaiveEngine") { ret = CreateNaiveEngine(); + } else if (stype == "ThreadedEngine") { + ret = CreateThreadedEnginePooled(); + } else if (stype == "ThreadedEnginePerDevie") { + ret = CreateThreadedEnginePerDevice(); } + CHECK_NE(ret, nullptr) << "Cannot find Eine " << type << " in registry"; if (!default_engine) { diff --git a/src/engine/engine_impl.h b/src/engine/engine_impl.h index e4c350656097..44452df7b9c5 100644 --- a/src/engine/engine_impl.h +++ b/src/engine/engine_impl.h @@ -65,11 +65,16 @@ inline T* Opr::Cast() { #endif } +/*! \brief Maximum number of GPUs */ +static constexpr std::size_t kMaxNumGPUs = 16; + // predeclare factory function for each type of engine /*! \return NaiveEngine instance */ Engine *CreateNaiveEngine(); -/*! \return ThreadedEngine instance */ -Engine *CreateThreadedEngine(); +/*! \return ThreadedEnginePooled instance */ +Engine *CreateThreadedEnginePooled(); +/*! \return ThreadedEnginePerDevie instance */ +Engine *CreateThreadedEnginePerDevice(); } // namespace engine } // namespace mxnet #endif // MXNET_ENGINE_ENGINE_IMPL_H_ diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index 7b2382d60df7..05038d443202 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -13,7 +13,6 @@ #include "../common/cuda_utils.h" namespace mxnet { - namespace engine { /*! @@ -44,9 +43,9 @@ class StreamManager { template RunContext StreamManager::GetRunContext( Context const& ctx) { + RunContext ret; switch (ctx.dev_mask) { - case cpu::kDevMask: - return {nullptr}; + case cpu::kDevMask: ret.stream = nullptr; break; case gpu::kDevMask: { #if MXNET_USE_CUDA std::size_t use_counter; @@ -63,21 +62,22 @@ RunContext StreamManager::GetRunContext( use_counter = counter; counter = (counter + 1) % kStreams; } - return {gpu_streams_.at(ctx.dev_id).at(use_counter)}; -#else // MXNET_USE_CUDA - LOG(FATAL) << "Please compile with CUDA enabled"; + ret.stream = gpu_streams_.at(ctx.dev_id).at(use_counter); + break; +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif // MXNET_USE_CUDA } } - return {nullptr}; + return ret; } template RunContext StreamManager::GetIORunContext( Context const& ctx) { + RunContext ret; switch (ctx.dev_mask) { - case cpu::kDevMask: - return {nullptr}; + case cpu::kDevMask: ret.stream = nullptr; break; case gpu::kDevMask: { #if MXNET_USE_CUDA CUDA_CALL(cudaSetDevice(ctx.dev_id)); @@ -87,13 +87,14 @@ RunContext StreamManager::GetIORunContext( gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream(false, false); } } - return {gpu_io_streams_.at(ctx.dev_id)}; -#else // MXNET_USE_CUDA - LOG(FATAL) << "Please compile with CUDA enabled"; + ret.stream = gpu_io_streams_.at(ctx.dev_id); + break; +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif // MXNET_USE_CUDA } } - return {nullptr}; + return ret; } template diff --git a/src/engine/thread_pool.h b/src/engine/thread_pool.h index ef99a93e58d1..acbf61896df4 100644 --- a/src/engine/thread_pool.h +++ b/src/engine/thread_pool.h @@ -17,14 +17,14 @@ namespace engine { /*! * \brief Thread pool. */ -template class ThreadPool { public: /*! - * \brief Constructor takes function to run and its arguments. + * \brief Constructor takes function to run. + * \param size size of the thread pool. + * \param func the function to run on the thread pool. */ - template - explicit ThreadPool(Function&& func, Args&&... args); + explicit ThreadPool(size_t size, std::function func); /*! * \brief Destructor. */ @@ -34,7 +34,7 @@ class ThreadPool { /*! * \brief Worker threads. */ - std::array worker_threads_; + std::vector worker_threads_; /*! * \brief Disallow default construction. */ @@ -45,16 +45,14 @@ class ThreadPool { DISALLOW_COPY_AND_ASSIGN(ThreadPool); }; -template -template -ThreadPool::ThreadPool(Function&& func, Args&&... args) { - for (auto&& i : worker_threads_) { - i = std::thread{std::forward(func), std::forward(args)...}; +ThreadPool::ThreadPool(size_t size, std::function func) + : worker_threads_(size) { + for (auto& i : worker_threads_) { + i = std::thread(func); } } -template -ThreadPool::~ThreadPool() noexcept(false) { +ThreadPool::~ThreadPool() noexcept(false) { for (auto&& i : worker_threads_) { i.join(); } @@ -62,5 +60,4 @@ ThreadPool::~ThreadPool() noexcept(false) { } // namespace engine } // namespace mxnet - #endif // MXNET_ENGINE_THREAD_POOL_H_ diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc new file mode 100644 index 000000000000..09c12da13938 --- /dev/null +++ b/src/engine/threaded_engine_perdevice.cc @@ -0,0 +1,170 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file threaded_engine_perdevice.cc + * \brief ThreadedEngine that uses fix amount of thread for each device. + */ +#include +#include +#include +#include +#include "./threaded_engine.h" +#include "./thread_pool.h" +#include "./stream_manager.h" + +namespace mxnet { +namespace engine { +/*! + * \brief ThreadedEngine uses per device threads. + * The policy of this Engine: + * - Execute Async operation immediately if pushed from Pusher. + * - Use fixed amount of threads for each device. + * - Use special threads for copy operations. + * - Each stream is allocated and binded to each of the thread. + */ +class ThreadedEnginePerDevice : public ThreadedEngine { + public: + ThreadedEnginePerDevice() noexcept(false) { + cpu_worker_nthreads_ = dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 2); + gpu_worker_nthreads_ = dmlc::GetEnv("MXNET_GPU_WORKER_NTHREADS", 2); + gpu_copy_nthreads_ = dmlc::GetEnv("MXNET_GPU_COPY_NTHREADS", 1); + + // create CPU task + auto *cpu_queue = &(cpu_worker_.task_queue); + cpu_worker_.pool.reset(new ThreadPool( + cpu_worker_nthreads_, [this, cpu_queue] { + this->CPUWorker(cpu_queue); + })); + // GPU tasks will be created lazily + } + ~ThreadedEnginePerDevice() noexcept(false) { + } + + protected: + void PushToExecute(OprBlock *opr_block, bool pusher_thread) override { + if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) { + CHECK_EQ(opr_block->ctx.dev_mask, cpu::kDevMask); + RunContext run_ctx; + run_ctx.stream = nullptr; + this->ExecuteOprBlock(run_ctx, opr_block); + } else { + const Context& ctx = opr_block->ctx; + if (ctx.dev_mask == cpu::kDevMask) { + cpu_worker_.task_queue.Push(opr_block); + } else { + CHECK_EQ(ctx.dev_mask, gpu::kDevMask); + ThreadWorkerBlock* block = this->GetGPUWorkerBlock( + ctx.dev_id, opr_block->opr->prop); + block->task_queue.Push(opr_block); + } + } + } + + private: + // working unit for each of the task. + struct ThreadWorkerBlock { + // task queue on this task + dmlc::ConcurrentBlockingQueue task_queue; + // thread pool that works on this task + std::unique_ptr pool; + // destructor + ~ThreadWorkerBlock() noexcept(false) { + task_queue.SignalForKill(); + } + }; + /*! \brief number of concurrent thread cpu worker uses */ + int cpu_worker_nthreads_; + /*! \brief number of concurrent thread each gpu worker uses */ + int gpu_worker_nthreads_; + /*! \brief number of concurrent thread each gpu copy worker uses */ + int gpu_copy_nthreads_; + // mutex used when creating a ThreadWorkerBlock + std::mutex create_mutex_; + // cpu worker + ThreadWorkerBlock cpu_worker_; + // workers doing normal works on GPU + std::array, kMaxNumGPUs> gpu_normal_workers_; + // workers doing copy works from/to GPU + std::array, kMaxNumGPUs> gpu_copy_workers_; + /*! + * \brief get GPU Task Worker + * \param dev_id the device id + * \param prop The property of the function. + */ + inline ThreadWorkerBlock *GetGPUWorkerBlock(size_t dev_id, + FnProperty prop) { + bool is_copy = (prop == FnProperty::kCopy); + CHECK_LT(dev_id, kMaxNumGPUs) + << "GPU Device index " << dev_id + << " exceed bound " << kMaxNumGPUs; + std::array, kMaxNumGPUs> *workers; + if (is_copy) { + workers = &gpu_copy_workers_; + } else { + workers = &gpu_normal_workers_; + } + ThreadWorkerBlock *block = workers->at(dev_id).get(); + if (block != nullptr) return block; + { + // only lock when block is not available. + std::lock_guard lock(create_mutex_); + // need to double check, because state can change + ThreadWorkerBlock *block = workers->at(dev_id).get(); + if (block != nullptr) return block; + int nthread = is_copy ? gpu_copy_nthreads_ : gpu_worker_nthreads_; + workers->at(dev_id).reset(new ThreadWorkerBlock()); + block = workers->at(dev_id).get(); + block->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, block] () { + this->GPUWorker(dev_id, is_copy, &(block->task_queue)); + })); + return block; + } + } + /*! + * \brief GPU worker that performs operations on a certain device. + * \param dev_id The device id of the worker. + * \param is_copy_worker whether the worker only do copy job + * \param task_queue the device id of the worker. + */ + inline void GPUWorker(int dev_id, + bool is_copy_worker, + dmlc::ConcurrentBlockingQueue* task_queue) { + #if MXNET_USE_CUDA + // allocate stream + mshadow::SetDevice(dev_id); + RunContext run_ctx; + mshadow::Stream *stream; + if (is_copy_worker) { + stream = mshadow::NewStream(true, MXNET_USE_CUDNN != 0); + } else { + stream = mshadow::NewStream(false, false); + } + run_ctx.stream = stream; + // execute task + OprBlock* opr_block; + while (task_queue->Pop(&opr_block)) { + this->ExecuteOprBlock(run_ctx, opr_block); + } + mshadow::DeleteStream(stream); + #endif + } + /*! + * \brief CPU worker that performs operations on CPU. + * \param task_queue the device id of the worker. + */ + inline void CPUWorker(dmlc::ConcurrentBlockingQueue* task_queue) { + RunContext run_ctx; + run_ctx.stream = nullptr; + // execute task + OprBlock* opr_block; + while (task_queue->Pop(&opr_block)) { + this->ExecuteOprBlock(run_ctx, opr_block); + } + } +}; + +Engine *CreateThreadedEnginePerDevice() { + return new ThreadedEnginePerDevice(); +} +} // namespace engine +} // namespace mxnet + diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc index 2dd2f27487eb..e8027eeea1f9 100644 --- a/src/engine/threaded_engine_pooled.cc +++ b/src/engine/threaded_engine_pooled.cc @@ -22,9 +22,9 @@ namespace engine { */ class ThreadedEnginePooled : public ThreadedEngine { public: - ThreadedEnginePooled() - : thread_pool_{[this]() { ThreadWorker(&task_queue_); }}, - io_thread_pool_{[this]() { ThreadWorker(&io_task_queue_); }} {} + ThreadedEnginePooled() : + thread_pool_(kNumWorkingThreads, [this]() { ThreadWorker(&task_queue_); }), + io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {} ~ThreadedEnginePooled() noexcept(false) { task_queue_.SignalForKill(); @@ -59,8 +59,8 @@ class ThreadedEnginePooled : public ThreadedEngine { /*! * \brief Thread pools. */ - ThreadPool thread_pool_; - ThreadPool<1> io_thread_pool_; + ThreadPool thread_pool_; + ThreadPool io_thread_pool_; /*! * \brief Worker. * \param task_queue Queue to work on. @@ -109,7 +109,7 @@ class ThreadedEnginePooled : public ThreadedEngine { } }; -Engine *CreateThreadedEngine() { +Engine *CreateThreadedEnginePooled() { return new ThreadedEnginePooled(); } } // namespace engine From daff15c20bc692835870a99e266f2c16bb019063 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 14 Sep 2015 12:57:52 -0700 Subject: [PATCH 3/4] Add Per Device Threaded Engine Policy, Explicit use copy --- .travis.yml | 1 + include/mxnet/engine.h | 6 +++-- scripts/travis_script.sh | 12 ++++++++-- src/engine/thread_pool.h | 32 ++++++++++--------------- src/engine/threaded_engine_perdevice.cc | 6 +++-- src/engine/threaded_engine_pooled.cc | 7 ++++-- src/ndarray/ndarray.cc | 14 +++++------ src/symbol/graph_executor.cc | 3 --- 8 files changed, 43 insertions(+), 38 deletions(-) diff --git a/.travis.yml b/.travis.yml index 5c7a5d2562a6..1d9e5bad4ed3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,7 @@ env: - TASK=python CXX=g++ - TASK=python3 CXX=g++ - TASK=python_naive CXX=g++ + - TASK=python_perdev CXX=g++ - TASK=cpp_unittest CXX=g++ # dependent apt packages diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index f185da8215c3..6bcf7b3da3da 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -33,8 +33,10 @@ typedef Opr* OprHandle; enum class FnProperty { /*! \brief Normal operation */ kNormal, - /*! \brief Copy operation between CPU and GPU */ - kCopy, + /*! \brief Copy operation from GPU to other devices */ + kCopyFromGPU, + /*! \brief Copy operation from CPU to other devices */ + kCopyToGPU, /*! \brief Asynchronous function call */ kAsync }; // enum class FnProperty diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh index 1b250afdf70b..07abd4881dce 100755 --- a/scripts/travis_script.sh +++ b/scripts/travis_script.sh @@ -40,7 +40,7 @@ if [ ${TASK} == "python3" ]; then make all || exit -1 export MXNET_ENGINE_TYPE=ThreadedEngine nosetests tests/python/unittest || exit -1 - nosetests tests/python/train || exit -1 + nosetests tests/python/train || exit -1 fi if [ ${TASK} == "python_naive" ]; then @@ -48,7 +48,15 @@ if [ ${TASK} == "python_naive" ]; then make all || exit -1 export MXNET_ENGINE_TYPE=NaiveEngine nosetests tests/python/unittest || exit -1 - nosetests tests/python/train || exit -1 + nosetests tests/python/train || exit -1 +fi + +if [ ${TASK} == "python_perdev" ]; then + echo "USE_CUDA=0" >> config.mk + make all || exit -1 + export MXNET_ENGINE_TYPE=ThreadedEnginePerDevice + nosetests tests/python/unittest || exit -1 + nosetests tests/python/train || exit -1 fi if [ ${TASK} == "cpp_unittest" ]; then diff --git a/src/engine/thread_pool.h b/src/engine/thread_pool.h index acbf61896df4..b88cddaa29c5 100644 --- a/src/engine/thread_pool.h +++ b/src/engine/thread_pool.h @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include #include "mxnet/base.h" @@ -24,11 +24,17 @@ class ThreadPool { * \param size size of the thread pool. * \param func the function to run on the thread pool. */ - explicit ThreadPool(size_t size, std::function func); - /*! - * \brief Destructor. - */ - ~ThreadPool() noexcept(false); + explicit ThreadPool(size_t size, std::function func) + : worker_threads_(size) { + for (auto& i : worker_threads_) { + i = std::thread(func); + } + } + ~ThreadPool() noexcept(false) { + for (auto&& i : worker_threads_) { + i.join(); + } + } private: /*! @@ -44,20 +50,6 @@ class ThreadPool { */ DISALLOW_COPY_AND_ASSIGN(ThreadPool); }; - -ThreadPool::ThreadPool(size_t size, std::function func) - : worker_threads_(size) { - for (auto& i : worker_threads_) { - i = std::thread(func); - } -} - -ThreadPool::~ThreadPool() noexcept(false) { - for (auto&& i : worker_threads_) { - i.join(); - } -} - } // namespace engine } // namespace mxnet #endif // MXNET_ENGINE_THREAD_POOL_H_ diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 09c12da13938..39681130e43d 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include "./threaded_engine.h" #include "./thread_pool.h" #include "./stream_manager.h" @@ -92,7 +93,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine { */ inline ThreadWorkerBlock *GetGPUWorkerBlock(size_t dev_id, FnProperty prop) { - bool is_copy = (prop == FnProperty::kCopy); + bool is_copy = (prop == FnProperty::kCopyFromGPU || + prop == FnProperty::kCopyToGPU); CHECK_LT(dev_id, kMaxNumGPUs) << "GPU Device index " << dev_id << " exceed bound " << kMaxNumGPUs; @@ -130,7 +132,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { dmlc::ConcurrentBlockingQueue* task_queue) { #if MXNET_USE_CUDA // allocate stream - mshadow::SetDevice(dev_id); + mshadow::SetDevice(dev_id); RunContext run_ctx; mshadow::Stream *stream; if (is_copy_worker) { diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc index e8027eeea1f9..0978b32ea8d6 100644 --- a/src/engine/threaded_engine_pooled.cc +++ b/src/engine/threaded_engine_pooled.cc @@ -86,7 +86,9 @@ class ThreadedEnginePooled : public ThreadedEngine { LOG(FATAL) << "Please compile with CUDA enabled"; #endif // MXNET_USE_CUDA } - auto&& rctx = opr_block->opr->prop == FnProperty::kCopy + bool is_copy = (opr_block->opr->prop == FnProperty::kCopyFromGPU || + opr_block->opr->prop == FnProperty::kCopyToGPU); + auto&& rctx = is_copy ? streams_.GetIORunContext(opr_block->ctx) : streams_.GetRunContext(opr_block->ctx); this->ExecuteOprBlock(rctx, opr_block); @@ -97,7 +99,8 @@ class ThreadedEnginePooled : public ThreadedEngine { */ void DoPushToQueue(OprBlock* opr_block) { switch (opr_block->opr->prop) { - case FnProperty::kCopy: { + case FnProperty::kCopyFromGPU: + case FnProperty::kCopyToGPU: { io_task_queue_.Push(opr_block); break; } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index e9be7e445da6..feb3de61be2d 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -169,7 +169,7 @@ void CopyFromTo(const NDArray &from, NDArray *to) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + from.ctx(), ret.ctx(), ctx); }, from.ctx(), const_vars, {ret.ptr_->var}); } else { #if MXNET_USE_CUDA @@ -178,28 +178,28 @@ void CopyFromTo(const NDArray &from, NDArray *to) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, ret.ctx(), const_vars, {ret.ptr_->var}); + }, ret.ctx(), const_vars, {ret.ptr_->var}, FnProperty::kCopyToGPU); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, from.ctx(), const_vars, {ret.ptr_->var}); + }, from.ctx(), const_vars, {ret.ptr_->var}, FnProperty::kCopyFromGPU); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, from.ctx(), const_vars, {ret.ptr_->var}); + }, from.ctx(), const_vars, {ret.ptr_->var}, FnProperty::kCopyFromGPU); } else { LOG(FATAL) << "unknown device mask"; } diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index 498ec4f942e0..8e08952eb234 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -205,9 +205,6 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { } } - for (const Resource& r : op_node.op_ctx.requested) { - exec.mutate_vars.push_back(static_cast(r.var)); - } // start setup exec function. for (const Resource& r : op_node.op_ctx.requested) { exec.mutate_vars.push_back(static_cast(r.var)); From 294b35e4f2801bd70b7edaac7eca6f258c7ad5ba Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 14 Sep 2015 13:15:44 -0700 Subject: [PATCH 4/4] fix compile --- src/engine/engine.cc | 4 ++-- tests/cpp/threaded_engine_unittest.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/engine/engine.cc b/src/engine/engine.cc index 9698b490157a..d5d8d3aa7a1b 100644 --- a/src/engine/engine.cc +++ b/src/engine/engine.cc @@ -21,12 +21,12 @@ inline Engine* CreateEngine() { ret = CreateNaiveEngine(); } else if (stype == "ThreadedEngine") { ret = CreateThreadedEnginePooled(); - } else if (stype == "ThreadedEnginePerDevie") { + } else if (stype == "ThreadedEnginePerDevice") { ret = CreateThreadedEnginePerDevice(); } CHECK_NE(ret, nullptr) - << "Cannot find Eine " << type << " in registry"; + << "Cannot find Engine " << type; if (!default_engine) { LOG(INFO) << "MXNet start using engine: " << type; } diff --git a/tests/cpp/threaded_engine_unittest.cc b/tests/cpp/threaded_engine_unittest.cc index 35e0ca3124b0..ffe3ee4ad3da 100644 --- a/tests/cpp/threaded_engine_unittest.cc +++ b/tests/cpp/threaded_engine_unittest.cc @@ -72,7 +72,7 @@ TEST(Engine, basics) { Foo(ctx, 42); cb(); }, - {}, {var}, mxnet::FnProperty::kCopy)); + {}, {var}, mxnet::FnProperty::kCopyFromGPU)); engine->Push(oprs.at(0), mxnet::Context{}); LOG(INFO) << "IO operator pushed, should wait for 2 seconds."; engine->WaitForVar(var);