diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 043ad050d23f..f6122c1f37b6 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -52,6 +52,16 @@ typedef void *DataIterHandle; * \return error info */ MXNET_DLL const char *MXGetLastError(); + +//------------------------------------- +// Part 0: Global State setups +//------------------------------------- +/*! + * \brief Seed the global random number generators in mxnet. + * \param seed the random number seed. + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXRandomSeed(int seed); //------------------------------------- // Part 1: NDArray creation and deletion //------------------------------------- diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index 0db270fbb958..72a4456f592a 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -8,6 +8,7 @@ #include #if DMLC_USE_CXX11 +#include #include #endif #include @@ -154,6 +155,15 @@ class Engine { * \return Engine singleton. */ static Engine* Get(); + /*! + * \brief Get shared pointer reference to engine singleton. + * Most user should not call this function. + * This function is called by another singleton X who requires + * engine to be destructed after X. + * + * \return A shared pointer to Engine singleton. + */ + static std::shared_ptr _GetSharedRef(); /*! * \brief Push an synchronous operation to the engine. * \param exec_fn Execution function that executes the operation. diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 15747a9bda02..047f9723916f 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -237,6 +237,13 @@ class NDArray { ret.shape_ = shape; return ret; } + /*! + * \brief Allocate the space if it is delayed allocated. + * This is an internal function used by system that normal user should not use + */ + inline void CheckAndAlloc() const { + ptr_->CheckAndAlloc(); + } private: /*! 
\brief the real data chunk that backs NDArray */ @@ -299,16 +306,6 @@ class NDArray { TShape shape_; /*! \brief offset in chunk */ size_t offset_; - - // add friend to helper functions - friend void CopyFromTo(const NDArray &from, NDArray *to); - template - friend void BinaryOp(const NDArray &lhs, const NDArray &rhs, NDArray *out); - template - friend void UnaryOp(const NDArray &lhs, const NDArray &rhs, NDArray *out); - template - friend void ScalarOp(const NDArray &lhs, const real_t &rhs, NDArray *out); - friend void SetValueOp(const real_t &rhs, NDArray *out); }; /*! @@ -380,6 +377,27 @@ NDArray operator/(const NDArray &lhs, const NDArray &rhs); */ NDArray operator/(const NDArray &lhs, const real_t &rhs); +/*! + * \brief Seed the random number generator. + * \param seed the seed to set to global random number generators. + */ +void RandomSeed(uint32_t seed); +/*! + * \brief Sample uniform distribution for each elements of out. + * \param begin lower bound of distribution. + * \param end upper bound of distribution. + * \param out output NDArray. + */ +void SampleUniform(real_t begin, real_t end, NDArray *out); + +/*! + * \brief Sample gaussian distribution for each elements of out. + * \param mu mean of gaussian distribution. + * \param sigma standard deviation of gaussian distribution. + * \param out output NDArray. + */ +void SampleGaussian(real_t mu, real_t sigma, NDArray *out); + //-------------------------------------------------------------- // The following part are API Registration of NDArray functions. 
//-------------------------------------------------------------- @@ -430,14 +448,12 @@ struct NDArrayFunctionReg * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_function(void fsetvalue(const real_t &rhs, - NDArray *out)) { - body = [fsetvalue] (NDArray **used_vars, - real_t *s, NDArray **mutate_vars) { + NDArray *out)) { + body = [fsetvalue] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) { fsetvalue(s[0], mutate_vars[0]); }; num_mutate_vars = 1; num_scalars = 1; - // type_mask = kNDArrayArgBeforeScalar; - this->add_argument("rhs", "real_t", "Right operand to the function."); + this->add_argument("src", "real_t", "Source input to the function."); return *this; } /*! @@ -447,8 +463,8 @@ struct NDArrayFunctionReg * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_function(void fbinary(const NDArray &lhs, - const NDArray &rhs, - NDArray *out)) { + const NDArray &rhs, + NDArray *out)) { body = [fbinary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) { fbinary(*used_vars[0], *used_vars[1], mutate_vars[0]); @@ -466,10 +482,10 @@ struct NDArrayFunctionReg * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_function(void fscalar(const NDArray &lhs, - const real_t &rhs, - NDArray *out)) { + const real_t &rhs, + NDArray *out)) { body = [fscalar] (NDArray **used_vars, - real_t *s, NDArray **mutate_vars) { + real_t *s, NDArray **mutate_vars) { fscalar(*used_vars[0], s[0], mutate_vars[0]); }; num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1; @@ -485,7 +501,7 @@ struct NDArrayFunctionReg * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_function(void funary(const NDArray &src, - NDArray *out)) { + NDArray *out)) { body = [funary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) { funary(*used_vars[0], mutate_vars[0]); diff --git a/include/mxnet/resource.h 
b/include/mxnet/resource.h index 8d03b08ad44a..a6c61f6f8862 100644 --- a/include/mxnet/resource.h +++ b/include/mxnet/resource.h @@ -55,13 +55,18 @@ struct Resource { void *ptr_; /*! * \brief Get random number generator. + * \param The stream to use in the random number generator. * \return the mshadow random number generator requested. * \tparam xpu the device type of random number generator. */ template - inline mshadow::Random* get_random() const { + inline mshadow::Random* get_random( + mshadow::Stream *stream) const { CHECK_EQ(req.type, ResourceRequest::kRandom); - return static_cast*>(ptr_); + mshadow::Random *ret = + static_cast*>(ptr_); + ret->set_stream(stream); + return ret; } /*! * \brief Get space requested as mshadow Tensor. @@ -81,5 +86,31 @@ struct Resource { static_cast(ptr_), shape, shape[ndim - 1], stream); } }; + +/*! \brief Global resource manager */ +class ResourceManager { + public: + /*! + * \brief Get resource of requested type. + * \param ctx the context of the request. + * \param req the resource request. + * \return the requested resource. + * \note The returned resource's ownership is + * still hold by the manager singleton. + * + */ + virtual Resource Request(Context ctx, const ResourceRequest &req) = 0; + /*! + * \brief Seed all the allocated random numbers. + * \param seed the seed to the random number generators on all devices. + */ + virtual void SeedRandom(uint32_t seed) = 0; + /*! \brief virtual destructor */ + virtual ~ResourceManager() {} + /*! + * \return Resource manager singleton. + */ + static ResourceManager *Get(); +}; } // namespace mxnet #endif // MXNET_RESOURCE_H_ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 5590c9f1cdad..71d303ff01f3 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -53,6 +53,15 @@ class Storage { * \return Storage singleton. */ static Storage* Get(); + /*! + * \brief Get shared pointer reference to engine singleton. 
+ * Most user should not call this function. + * This function is called by another singleton X who requires + * Storage to be destructed after X. + * + * \return A shared pointer to Storage singleton. + */ + static std::shared_ptr _GetSharedRef(); private: /*! diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 1417b1262505..8ccf04519e22 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -16,5 +16,6 @@ from . import io # use mx.nd as short for mx.ndarray from . import ndarray as nd +from . import random __version__ = "0.1.0" diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 707591f9fc19..1e5e4652aced 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -20,8 +20,12 @@ def __init__(self, device_type, device_id=0): device_id : int (default=0) the device id of the device, needed for GPU """ - self.device_mask = Context.devtype2mask[device_type] - self.device_id = device_id + if isinstance(device_type, Context): + self.device_mask = device_type.device_mask + self.device_id = device_type.device_id + else: + self.device_mask = Context.devtype2mask[device_type] + self.device_id = device_id self._old_ctx = None @property diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 5b9323639298..a9634292f9c6 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -340,8 +340,7 @@ def zeros(shape, ctx=None): ---------- shape : tuple shape of the NDArray. - - ctx : Context, optional + ctx : Context, optional. The context of the NDArray, default to current default context. Returns @@ -360,7 +359,6 @@ def ones(shape, ctx=None): ---------- shape : tuple shape of the NDArray. - ctx : Context, optional The context of the NDArray, default to current default context. 
diff --git a/python/mxnet/random.py b/python/mxnet/random.py new file mode 100644 index 000000000000..489a8bd16097 --- /dev/null +++ b/python/mxnet/random.py @@ -0,0 +1,99 @@ +# coding: utf-8 +# pylint: disable=no-member, protected-access +"""Random Number interface of mxnet.""" +from __future__ import absolute_import + +import ctypes +from .base import _LIB, check_call +from .ndarray import NDArray, empty + + +def uniform(low, high, shape=None, ctx=None, out=None): + """Generate uniform distribution in [low, high) with shape. + + Parameters + ---------- + low : float + The lower bound of distribution. + high : float + The upper bound of distribution. + shape : tuple, optional + Output shape of the NDArray generated. + ctx : Context, optional + Context of output NDArray, will use default context if not specified. + out : NDArray, optional + Output place holder + + Returns + ------- + out : NDArray + The result NDArray with generated result. + """ + if out is not None: + if shape is not None or ctx is not None: + raise ValueError('shape and ctx is not needed when out is specified') + else: + if shape is None: + raise ValueError('shape is required when out is not specified') + if isinstance(shape, int): + shape = (shape,) + out = empty(shape, ctx) + return NDArray._random_uniform(low, high, out=out) + + +def normal(mean, stdvar, shape=None, ctx=None, out=None): + """Generate normal(Gaussian) distribution N(mean, stdvar^2) with shape. + + Parameters + ---------- + mean : float + The mean of the normal distribution. + stdvar : float + The standard deviation of normal distribution. + shape : tuple, optional + Output shape of the NDArray generated. + ctx : Context, optional + Context of output NDArray, will use default context if not specified. + out : NDArray, optional + Output place holder + + Returns + ------- + out : NDArray + The result NDArray with generated result. 
+ """ + if out is not None: + if shape is not None or ctx is not None: + raise ValueError('shape and ctx is not needed when out is specified') + else: + if shape is None: + raise ValueError('shape is required when out is not specified') + if isinstance(shape, int): + shape = (shape,) + out = empty(shape, ctx) + return NDArray._random_gaussian(mean, stdvar, out=out) + + +def seed(seed_state): + """Seed the random number generators in mxnet. + + This seed will affect behavior of functions in this module, + as well as results from executors that contain random numbers, + such as Dropout operators. + + Parameters + ---------- + seed_state : int + The random number seed to set to all devices. + + Notes + ----- + The random number generator of mxnet is by default device specific. + This means if you set the same seed, the random number sequence + generated from GPU0 can be different from CPU. + """ + if not isinstance(seed_state, int): + raise ValueError('seed_state must be int') + seed_state = ctypes.c_int(int(seed_state)) + check_call(_LIB.MXRandomSeed(seed_state)) + diff --git a/src/c_api.cc b/src/c_api.cc index 6427a6357c90..222440f76855 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -181,6 +181,12 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, } // NOTE: return value is added in API_END +int MXRandomSeed(int seed) { + API_BEGIN(); + mxnet::RandomSeed(seed); + API_END(); +} + int MXNDArrayCreateNone(NDArrayHandle *out) { API_BEGIN(); *out = new NDArray(); diff --git a/src/common/lazy_alloc_array.h b/src/common/lazy_alloc_array.h new file mode 100644 index 000000000000..28c4c259c603 --- /dev/null +++ b/src/common/lazy_alloc_array.h @@ -0,0 +1,106 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file lazy_alloc_array.h + * \brief An array that lazily allocate elements as + * First time the cell get visited.
+ */ +#ifndef MXNET_COMMON_LAZY_ALLOC_ARRAY_H_ +#define MXNET_COMMON_LAZY_ALLOC_ARRAY_H_ + +#include <dmlc/logging.h> +#include <memory> +#include <mutex> +#include <array> +#include <vector> + +namespace mxnet { +namespace common { + +template<typename TElem> +class LazyAllocArray { + public: + /*! + * \brief Get element of corresponding index, + * if it is not created create by creator + * \param index the array index position + * \param creator a lambda function to create new element when needed. + */ + template<typename FCreate> + inline TElem* Get(int index, FCreate creator); + /*! + * \brief for each not null element of the array, call fvisit + * \param fvisit a function of (size_t, TElem*) + */ + template<typename FVisit> + inline void ForEach(FVisit fvisit); + /*! \brief clear all the allocated elements in array */ + inline void Clear(); + + private: + /*! \brief the initial size of the array */ + static constexpr std::size_t kInitSize = 16; + /*! \brief mutex used during creation */ + std::mutex create_mutex_; + /*! \brief internal data for initial size */ + std::array<std::unique_ptr<TElem>, kInitSize> head_; + /*!
\brief overflow array of more elements */ + std::vector > more_; +}; + +// implementations +template +template +inline TElem* LazyAllocArray::Get(int index, FCreate creator) { + CHECK_GE(index, 0); + size_t idx = static_cast(index); + if (idx < kInitSize) { + TElem *ptr = head_[idx].get(); + if (ptr != nullptr) { + return ptr; + } else { + std::lock_guard lock(create_mutex_); + TElem *ptr = head_[idx].get(); + if (ptr != nullptr) return ptr; + head_[idx].reset(ptr = creator()); + return ptr; + } + } else { + std::lock_guard lock(create_mutex_); + idx -= kInitSize; + if (more_.size() <= idx) more_.resize(idx + 1); + TElem *ptr = more_[idx].get(); + if (ptr != nullptr) return ptr; + more_[idx].reset(ptr = creator()); + return ptr; + } +} + +template +inline void LazyAllocArray::Clear() { + std::lock_guard lock(create_mutex_); + for (size_t i = 0; i < head_.size(); ++i) { + head_[i].reset(nullptr); + } + for (size_t i = 0; i < more_.size(); ++i) { + more_[i].reset(nullptr); + } +} + +template +template +inline void LazyAllocArray::ForEach(FVisit fvisit) { + std::lock_guard lock(create_mutex_); + for (size_t i = 0; i < head_.size(); ++i) { + if (head_[i].get() != nullptr) { + fvisit(i, head_[i].get()); + } + } + for (size_t i = 0; i < more_.size(); ++i) { + if (more_[i].get() != nullptr) { + fvisit(i + kInitSize, more_[i].get()); + } + } +} +} // namespace common +} // namespace mxnet +#endif // MXNET_COMMON_LAZY_ALLOC_ARRAY_H_ diff --git a/src/engine/engine.cc b/src/engine/engine.cc index 6cd82f96050b..70bd18c442b7 100644 --- a/src/engine/engine.cc +++ b/src/engine/engine.cc @@ -34,8 +34,13 @@ inline Engine* CreateEngine() { } } // namespace engine +std::shared_ptr Engine::_GetSharedRef() { + static std::shared_ptr sptr(engine::CreateEngine()); + return sptr; +} + Engine* Engine::Get() { - static std::unique_ptr inst(engine::CreateEngine()); - return inst.get(); + static Engine *inst = _GetSharedRef().get(); + return inst; } } // namespace mxnet diff --git 
a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 8e59c59ab30e..1a3144e783ec 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -4,13 +4,13 @@ * \brief implements base threaded engine. * \author Yutian Li */ -#include "threaded_engine.h" #include #include #include #include #include #include +#include "./threaded_engine.h" #include "../common/cuda_utils.h" namespace mxnet { diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index c67fa5431613..fa29939d291f 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -12,8 +12,8 @@ #include #include #include -#include #include +#include #include #include "./engine_impl.h" #include "../common/object_pool.h" diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 0a3da50e69be..fada76801250 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -10,7 +10,7 @@ #include #include "./threaded_engine.h" #include "./thread_pool.h" -#include "./stream_manager.h" +#include "../common/lazy_alloc_array.h" namespace mxnet { namespace engine { @@ -38,6 +38,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine { // GPU tasks will be created lazily } ~ThreadedEnginePerDevice() noexcept(false) { + // wait until all the tasks are completed. + this->WaitForAll(); } protected: @@ -82,48 +84,35 @@ class ThreadedEnginePerDevice : public ThreadedEngine { int gpu_worker_nthreads_; /*! \brief number of concurrent thread each gpu copy worker uses */ int gpu_copy_nthreads_; - // mutex used when creating a ThreadWorkerBlock - std::mutex create_mutex_; // cpu worker ThreadWorkerBlock cpu_worker_; // workers doing normal works on GPU - std::array, kMaxNumGPUs> gpu_normal_workers_; + common::LazyAllocArray gpu_normal_workers_; // workers doing copy works from/to GPU - std::array, kMaxNumGPUs> gpu_copy_workers_; + common::LazyAllocArray gpu_copy_workers_; /*! 
* \brief get GPU Task Worker * \param dev_id the device id * \param prop The property of the function. */ - inline ThreadWorkerBlock *GetGPUWorkerBlock(size_t dev_id, + inline ThreadWorkerBlock *GetGPUWorkerBlock(int dev_id, FnProperty prop) { bool is_copy = (prop == FnProperty::kCopyFromGPU || prop == FnProperty::kCopyToGPU); - CHECK_LT(dev_id, kMaxNumGPUs) - << "GPU Device index " << dev_id - << " exceed bound " << kMaxNumGPUs; - std::array, kMaxNumGPUs> *workers; + auto *arr = &gpu_normal_workers_; + int nthread = gpu_worker_nthreads_; if (is_copy) { - workers = &gpu_copy_workers_; - } else { - workers = &gpu_normal_workers_; - } - ThreadWorkerBlock *block = workers->at(dev_id).get(); - if (block != nullptr) return block; - { - // only lock when block is not available. - std::lock_guard lock(create_mutex_); - // need to double check, because state can change - ThreadWorkerBlock *block = workers->at(dev_id).get(); - if (block != nullptr) return block; - int nthread = is_copy ? gpu_copy_nthreads_ : gpu_worker_nthreads_; - workers->at(dev_id).reset(new ThreadWorkerBlock()); - block = workers->at(dev_id).get(); - block->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, block] () { - this->GPUWorker(dev_id, is_copy, &(block->task_queue)); - })); - return block; + arr = &gpu_copy_workers_; + nthread = gpu_copy_nthreads_; } + + return arr->Get(dev_id, [this, dev_id, is_copy, nthread]() { + auto block = new ThreadWorkerBlock(); + block->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, block] () { + this->GPUWorker(dev_id, is_copy, &(block->task_queue)); + })); + return block; + }); } /*! * \brief GPU worker that performs operations on a certain device. 
diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc index 0978b32ea8d6..8ab7092dfce9 100644 --- a/src/engine/threaded_engine_pooled.cc +++ b/src/engine/threaded_engine_pooled.cc @@ -27,6 +27,9 @@ class ThreadedEnginePooled : public ThreadedEngine { io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {} ~ThreadedEnginePooled() noexcept(false) { + // wait until all the tasks are completed. + // TODO(hotpxl) think if this is the correct thing to do + this->WaitForAll(); task_queue_.SignalForKill(); io_task_queue_.SignalForKill(); } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index feb3de61be2d..920fe1bed709 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "./ndarray_function.h" @@ -45,28 +46,28 @@ inline void BinaryOp(const NDArray &lhs, NDArray ret = *out; // get the const variables std::vector const_vars; - if (lhs.ptr_->var != ret.ptr_->var) const_vars.push_back(lhs.ptr_->var); - if (rhs.ptr_->var != ret.ptr_->var) const_vars.push_back(rhs.ptr_->var); + if (lhs.var() != ret.var()) const_vars.push_back(lhs.var()); + if (rhs.var() != ret.var()) const_vars.push_back(rhs.var()); // redirect everything to mshadow operations switch (lhs.ctx().dev_mask) { case cpu::kDevMask: { Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); + ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - }, lhs.ctx(), const_vars, {ret.ptr_->var}); + }, lhs.ctx(), const_vars, {ret.var()}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); + ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, lhs.ctx(), const_vars, {ret.ptr_->var}); + }, lhs.ctx(), const_vars, {ret.var()}); break; 
} #endif @@ -81,21 +82,21 @@ inline void SetValueOp(const real_t &rhs, NDArray *out) { switch (ret.ctx().dev_mask) { case cpu::kDevMask: { Engine::Get()->PushSync([rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); + ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(rhs, &tmp, ctx); - }, ret.ctx(), {}, {ret.ptr_->var}); + }, ret.ctx(), {}, {ret.var()}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { Engine::Get()->PushSync([rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); + ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(rhs, &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, ret.ctx(), {}, {ret.ptr_->var}); + }, ret.ctx(), {}, {ret.var()}); break; } #endif @@ -123,27 +124,27 @@ inline void ScalarOp(const NDArray &lhs, NDArray ret = *out; // get the const variables std::vector const_vars; - if (lhs.ptr_->var != ret.ptr_->var) const_vars.push_back(lhs.ptr_->var); + if (lhs.var() != ret.var()) const_vars.push_back(lhs.var()); // redirect everything to mshadow operations switch (lhs.ctx().dev_mask) { case cpu::kDevMask: { Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); + ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs, &tmp, ctx); - }, lhs.ctx(), const_vars, {ret.ptr_->var}); + }, lhs.ctx(), const_vars, {ret.var()}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); + ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs, &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, lhs.ctx(), const_vars, {ret.ptr_->var}); + }, lhs.ctx(), const_vars, {ret.var()}); break; } #endif @@ -162,44 +163,44 @@ void CopyFromTo(const NDArray &from, NDArray *to) { int b = to->ctx().dev_mask; std::vector const_vars; - if (from.ptr_->var != ret.ptr_->var) const_vars.push_back(from.ptr_->var); + if (from.var() != ret.var()) 
const_vars.push_back(from.var()); if (a == cpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); + ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); - }, from.ctx(), const_vars, {ret.ptr_->var}); + }, from.ctx(), const_vars, {ret.var()}); } else { #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); + ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, ret.ctx(), const_vars, {ret.ptr_->var}, FnProperty::kCopyToGPU); + }, ret.ctx(), const_vars, {ret.var()}, FnProperty::kCopyToGPU); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); + ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, from.ctx(), const_vars, {ret.ptr_->var}, FnProperty::kCopyFromGPU); + }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); + ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, from.ctx(), const_vars, {ret.ptr_->var}, FnProperty::kCopyFromGPU); + }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU); } else { LOG(FATAL) << "unknown device mask"; } @@ -209,6 +210,54 @@ void CopyFromTo(const NDArray &from, NDArray *to) { } } + +template +inline void SampleOP(const real_t &a, + const real_t &b, + NDArray *out) { + CHECK(!out->is_none()); + Resource resource = 
ResourceManager::Get()->Request( + out->ctx(), ResourceRequest::kRandom); + // important: callback must always capture by value + NDArray ret = *out; + // redirect everything to mshadow operations + switch (out->ctx().dev_mask) { + case cpu::kDevMask: { + Engine::Get()->PushSync([a, b, resource, ret](RunContext ctx) { + ret.CheckAndAlloc(); + TBlob tmp = ret.data(); + ndarray::EvalRandom(a, b, resource, &tmp, ctx); + }, out->ctx(), {}, {ret.var(), resource.var}); + break; + } +#if MXNET_USE_CUDA + case gpu::kDevMask: { + Engine::Get()->PushSync([a, b, resource, ret](RunContext ctx) { + ret.CheckAndAlloc(); + TBlob tmp = ret.data(); + ndarray::EvalRandom(a, b, resource, &tmp, ctx); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, out->ctx(), {}, {ret.var(), resource.var}); + break; + } +#endif + default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; + } +} + +void SampleUniform(real_t begin, real_t end, NDArray *out) { + SampleOP(begin, end, out); +} + +void SampleGaussian(real_t mu, real_t sigma, NDArray *out) { + SampleOP(mu, sigma, out); +} + +void RandomSeed(uint32_t seed) { + ResourceManager::Get()->SeedRandom(seed); +} + template inline NDArray BinaryOpRet(const NDArray &lhs, const NDArray &rhs) { @@ -423,16 +472,13 @@ MXNET_REGISTER_NDARRAY_FUN(_div).set_function(BinaryOp); // register API function // those with underscore will be registered at NDArray -// scalar MXNET_REGISTER_NDARRAY_FUN(_plus_scalar).set_function(ScalarOp); MXNET_REGISTER_NDARRAY_FUN(_minus_scalar).set_function(ScalarOp); MXNET_REGISTER_NDARRAY_FUN(_mul_scalar).set_function(ScalarOp); MXNET_REGISTER_NDARRAY_FUN(_div_scalar).set_function(ScalarOp); // register API function -// those with underscore will be registered at NDArray -// scalar -// reverse scalar +// scalar, reverse scalar MXNET_REGISTER_NDARRAY_FUN(_rminus_scalar).set_function(ScalarOp); MXNET_REGISTER_NDARRAY_FUN(_rdiv_scalar).set_function(ScalarOp); @@ -442,4 +488,18 @@ MXNET_REGISTER_NDARRAY_FUN(_copyto) 
.set_function(CopyFromTo) .set_type_mask(kNDArrayArgBeforeScalar); +// register random number generators +MXNET_REGISTER_NDARRAY_FUN(_random_uniform) +.set_body([](NDArray **u, real_t *s, NDArray **out) { + SampleUniform(s[0], s[1], out[0]); + }) +.set_num_scalars(2) +.set_num_mutate_vars(1); + +MXNET_REGISTER_NDARRAY_FUN(_random_gaussian) +.set_body([](NDArray **u, real_t *s, NDArray **out) { + SampleGaussian(s[0], s[1], out[0]); + }) +.set_num_scalars(2) +.set_num_mutate_vars(1); } // namespace mxnet diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h index 6494d64a148e..34a81af1bb39 100644 --- a/src/ndarray/ndarray_function-inl.h +++ b/src/ndarray/ndarray_function-inl.h @@ -24,16 +24,6 @@ } #endif -#ifndef DECL_SETVALUE -#define DECL_SETVALUE(XPU) \ - template<> \ - void Eval(const real_t &rhs, TBlob *ret, RunContext ctx) { \ - mshadow::Stream *s = static_cast*>(ctx.stream); \ - ret->FlatTo2D(s) = rhs; \ - } -#endif - - #if defined(__CUDACC__) #define DEVICE gpu #else @@ -45,9 +35,9 @@ namespace ndarray { // true implementation template inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; - mshadow::Stream *s = static_cast*>(ctx.stream); + mshadow::Stream *s = ctx.get_stream(); ret->FlatTo2D(s) = F(lhs.FlatTo2D(s), rhs.FlatTo2D(s)); @@ -57,7 +47,7 @@ template inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { using namespace mshadow::expr; - mshadow::Stream *s = static_cast*>(ctx.stream); + mshadow::Stream *s = ctx.get_stream(); if (reverse) { ret->FlatTo2D(s) = F(rhs, lhs.FlatTo2D(s)); @@ -67,6 +57,39 @@ inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, } } +template<> +void EvalRandom( + const real_t &a, + const real_t &b, + const Resource &resource, + TBlob *ret, + RunContext ctx) { + typedef DEVICE xpu; + mshadow::Stream *s = ctx.get_stream(); + mshadow::Tensor tmp = 
ret->FlatTo2D(s); + mshadow::Random *prnd = resource.get_random(s); + prnd->SampleUniform(&tmp, a, b); +} + +template<> +void EvalRandom( + const real_t &mu, + const real_t &sigma, + const Resource &resource, + TBlob *ret, + RunContext ctx) { + typedef DEVICE xpu; + mshadow::Stream *s = ctx.get_stream(); + mshadow::Tensor tmp = ret->FlatTo2D(s); + mshadow::Random *prnd = resource.get_random(s); + prnd->SampleGaussian(&tmp, mu, sigma); +} + +template<> +void Eval(const real_t &rhs, TBlob *ret, RunContext ctx) { + mshadow::Stream *s = ctx.get_stream(); + ret->FlatTo2D(s) = rhs; +} // declarations DECL_BINARY(DEVICE, Plus, EvalBinary_) @@ -82,8 +105,6 @@ DECL_SCALAR(DEVICE, Plus, EvalScalar_, false) DECL_SCALAR(DEVICE, Minus, EvalScalar_, false) DECL_SCALAR(DEVICE, Mul, EvalScalar_, false) DECL_SCALAR(DEVICE, Div, EvalScalar_, false) -// -DECL_SETVALUE(DEVICE) } // namespace ndarray } // namespace mxnet diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h index 0a0dc89ccdde..a54766c75002 100644 --- a/src/ndarray/ndarray_function.h +++ b/src/ndarray/ndarray_function.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace mxnet { /*! 
\brief namespace to support all possible Ndarray operator */ @@ -23,15 +24,25 @@ struct BinaryBase { struct Plus : public BinaryBase { typedef mshadow::op::plus mshadow_op; }; + struct Minus : public BinaryBase { typedef mshadow::op::minus mshadow_op; }; + struct Mul : public BinaryBase { typedef mshadow::op::mul mshadow_op; }; + struct Div : public BinaryBase { typedef mshadow::op::div mshadow_op; }; + +// type holder for random number generators +struct UniformDistribution {}; + +struct GaussianDistribution {}; + + template void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx); @@ -41,6 +52,12 @@ void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx); template void Eval(const real_t &rhs, TBlob *ret, RunContext ctx); +template +void EvalRandom(const real_t &a, + const real_t &b, + const Resource &resource, + TBlob *ret, RunContext ctx); + // copy function when only cpu is involved template void Copy(const TBlob &from, TBlob *to, diff --git a/src/resource.cc b/src/resource.cc new file mode 100644 index 000000000000..62d0b0080fc3 --- /dev/null +++ b/src/resource.cc @@ -0,0 +1,121 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file resource.cc + * \brief Implementation of resource manager. + */ +#include +#include +#include +#include +#include "./common/lazy_alloc_array.h" + +namespace mxnet { +namespace resource { + +// implements resource manager +class ResourceManagerImpl : public ResourceManager { + public: + ResourceManagerImpl() : global_seed_(0) { + engine_ref_ = Engine::_GetSharedRef(); + cpu_rand_ = new ResourceRandom( + Context(cpu::kDevMask, 0), global_seed_); + } + ~ResourceManagerImpl() { + // need explicit delete, before engine get killed + delete cpu_rand_; +#if MXNET_USE_CUDA + gpu_rand_.Clear(); +#endif + // release the reference to engine. 
+ engine_ref_ = nullptr; + } + + // request resources + Resource Request(Context ctx, const ResourceRequest &req) override { + if (req.type == ResourceRequest::kRandom) { + if (ctx.dev_mask == cpu::kDevMask) { + return cpu_rand_->resource; + } else { + CHECK_EQ(ctx.dev_mask, gpu::kDevMask); +#if MSHADOW_USE_CUDA + return gpu_rand_.Get(ctx.dev_id, [ctx, this]() { + return new ResourceRandom(ctx, global_seed_); + })->resource; +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } + } else { + LOG(FATAL) << "Unknown supported type " << req.type; + } + Resource ret; + return ret; + } + + void SeedRandom(uint32_t seed) override { + global_seed_ = seed; + cpu_rand_->Seed(global_seed_); +#if MXNET_USE_CUDA + gpu_rand_.ForEach([seed](size_t i, ResourceRandom *p) { + p->Seed(seed); + }); +#endif + } + + private: + /*! \brief Maximum number of GPUs */ + static constexpr std::size_t kMaxNumGPUs = 16; + /*! \brief Random number magic number to seed different random numbers */ + static constexpr uint32_t kRandMagic = 127UL; + /*! \brief Reference to the engine */ + std::shared_ptr engine_ref_; + + // the random number resources + template + struct ResourceRandom { + /*! \brief pointer to PRNG */ + mshadow::Random *prnd; + /*! \brief the context of the PRNG */ + Context ctx; + /*! \brief resource representation */ + Resource resource; + /*! 
\brief constructor */ + explicit ResourceRandom(Context ctx, uint32_t global_seed) + : ctx(ctx) { + mshadow::SetDevice(ctx.dev_id); + resource.var = Engine::Get()->NewVariable(); + prnd = new mshadow::Random(ctx.dev_id + global_seed * kRandMagic); + resource.ptr_ = prnd; + resource.req = ResourceRequest(ResourceRequest::kRandom); + } + ~ResourceRandom() { + mshadow::Random *r = prnd; + Engine::Get()->DeleteVariable( + [r](RunContext rctx){ delete r; }, ctx, resource.var); + } + // set seed to a PRNG + inline void Seed(uint32_t global_seed) { + uint32_t seed = ctx.dev_id + global_seed * kRandMagic; + mshadow::Random *r = prnd; + Engine::Get()->PushSync([r, seed](RunContext rctx) { + r->set_stream(rctx.get_stream()); + r->Seed(seed); + }, ctx, {}, {resource.var}); + } + }; + /*! \brief internal seed to the random number generator */ + uint32_t global_seed_; + /*! \brief CPU random number resources */ + ResourceRandom *cpu_rand_; +#if MXNET_USE_CUDA + /*! \brief random number generator for GPU */ + common::LazyAllocArray > gpu_rand_; +#endif +}; +} // namespace resource + +ResourceManager* ResourceManager::Get() { + static resource::ResourceManagerImpl inst; + return &inst; +} +} // namespace mxnet diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 1b0fb0dff528..cdcdbc15840b 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -1,11 +1,12 @@ /*! 
* Copyright (c) 2015 by Contributors */ -#include "mxnet/storage.h" +#include #include #include #include #include +#include #include "storage_manager.h" #include "naive_storage_manager.h" #include "pooled_storage_manager.h" @@ -94,10 +95,14 @@ void Storage::Free(Storage::Handle handle) { Storage::~Storage() = default; +std::shared_ptr Storage::_GetSharedRef() { + static std::shared_ptr inst(new Storage()); + return inst; +} + Storage* Storage::Get() { - // This function is thread-safe in C++11 - static Storage inst; - return &inst; + static Storage *ptr = _GetSharedRef().get(); + return ptr; } Storage::Storage() : impl_{new Impl{}} {} diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index c1587a3657b4..a30a4c0d897c 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -208,7 +208,7 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { // start setup exec function. for (const Resource& r : op_node.op_ctx.requested) { - exec.mutate_vars.push_back(static_cast(r.var)); + exec.mutate_vars.push_back(r.var); } Operator* op = op_node.op.get(); @@ -471,6 +471,9 @@ void GraphExecutor::InitDataEntryMemory() { mshadow::Shape1(entry.resource.req.space_num_reals)); entry.resource.ptr_ = entry.data.data().dptr_; entry.resource.var = entry.data.var(); + } else if (entry.resource.req.type == ResourceRequest::kRandom) { + entry.resource = ResourceManager::Get()->Request( + op_nodes_[nid].ctx, entry.resource.req); } else { LOG(FATAL) << "resource type not yet supported"; } diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py new file mode 100644 index 000000000000..10be569e8f76 --- /dev/null +++ b/tests/python/unittest/test_random.py @@ -0,0 +1,32 @@ +import os +import mxnet as mx +import numpy as np + +def same(a, b): + return np.sum(a != b) == 0 + +def check_with_device(device): + with mx.Context(device): + a, b = -10, 10 + mu, sigma = 10, 2 + shape = (100, 100) + mx.random.seed(128) + ret1 = 
mx.random.normal(mu, sigma, shape) + un1 = mx.random.uniform(a, b, shape) + mx.random.seed(128) + ret2 = mx.random.normal(mu, sigma, shape) + un2 = mx.random.uniform(a, b, shape) + assert same(ret1.asnumpy(), ret2.asnumpy()) + assert same(un1.asnumpy(), un2.asnumpy()) + assert abs(np.mean(ret1.asnumpy()) - mu) < 0.1 + assert abs(np.std(ret1.asnumpy()) - sigma) < 0.1 + assert abs(np.mean(un1.asnumpy()) - (a+b)/2) < 0.1 + + +def test_random(): + check_with_device(mx.cpu()) + + +if __name__ == '__main__': + test_random() +