Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS)
endif

#ENGINE=simple_engine.o dag_engine.o
ENGINE=naive_engine.o
BIN = tests/test_simple_engine
OBJ = narray_function_cpu.o
OBJCXX11 = batch_norm_cpu.o reshape_cpu.o dag_engine.o simple_engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o
OBJCXX11 = batch_norm_cpu.o reshape_cpu.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o $(ENGINE)
CUOBJ =
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
Expand All @@ -81,8 +83,9 @@ $(DMLC_CORE)/libdmlc.a:
+ cd $(DMLC_CORE); make libdmlc.a config=$(ROOTDIR)/$(config); cd $(ROOTDIR)

storage.o: src/storage/storage.cc
# Object-file prerequisites for the DAG-engine implementations.
# (Deduplicated: the simple_engine.o rule was listed twice; make merges
# duplicate prerequisite lines, but the repetition was redundant.)
naive_engine.o: src/dag_engine/naive_engine.cc
dag_engine.o: src/dag_engine/dag_engine.cc
simple_engine.o: src/dag_engine/simple_engine.cc
narray.o: src/narray/narray.cc
narray_function_cpu.o: src/narray/narray_function.cc src/narray/narray_function-inl.h
narray_function_gpu.o: src/narray/narray_function.cu src/narray/narray_function-inl.h
Expand Down Expand Up @@ -120,10 +123,10 @@ $(BIN) :
$(CXX) $(CFLAGS) -std=c++0x -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS)

$(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
$(CXX) -c $(CFLAGS) -o $@ $(filter %.cpp %.c %.cc, $^)

$(OBJCXX11) :
$(CXX) -std=c++0x -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
$(CXX) -std=c++0x -c $(CFLAGS) -o $@ $(filter %.cpp %.c %.cc, $^)

$(SLIB) :
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS)
Expand Down
6 changes: 3 additions & 3 deletions doc/Doxyfile
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ LOOKUP_CACHE_SIZE = 0
# normally produced when WARNINGS is set to YES.
# The default value is: NO.

EXTRACT_ALL = NO
EXTRACT_ALL = YES

# If the EXTRACT_PRIVATE tag is set to YES all private members of a class will
# be included in the documentation.
Expand Down Expand Up @@ -2088,7 +2088,7 @@ HIDE_UNDOC_RELATIONS = YES
# set to NO
# The default value is: YES.

HAVE_DOT = NO
HAVE_DOT = YES

# The DOT_NUM_THREADS specifies the number of dot invocations doxygen is allowed
# to run in parallel. When set to 0 doxygen will base this on the number of
Expand Down Expand Up @@ -2154,7 +2154,7 @@ GROUP_GRAPHS = YES
# The default value is: NO.
# This tag requires that the tag HAVE_DOT is set to YES.

UML_LOOK = NO
UML_LOOK = YES

# If the UML_LOOK tag is enabled, the fields and methods are shown inside the
# class node. If there are many fields or methods and many nodes the graph may
Expand Down
17 changes: 5 additions & 12 deletions include/mxnet/dag_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ class DAGEngine {
* mutate.
* \param mutate_vars The variables that current operation will mutate.
*/
void Push(Fn exec_fun, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars);
virtual void Push(Fn exec_fun, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) = 0;
/*!
* \brief Push an asynchronous operation to the DAG engine.
* \param exec_fun Execution function, this function takes a parameter
Expand Down Expand Up @@ -141,20 +141,13 @@ class DAGEngine {
/*!
* \brief Virtual destructor.
*/
virtual ~DAGEngine() noexcept(false);
virtual ~DAGEngine() noexcept(false) {}
/*!
* \return DAG engine singleton.
*/
static DAGEngine* Get();

protected:
/*!
* \brief Hidden constructors.
*/
DAGEngine();

private:
DISALLOW_COPY_AND_ASSIGN(DAGEngine);
// removed DISALLOW_COPY_AND_ASSIGN since this is a virtual base class.
}; // class DAGEngine

} // namespace mxnet
Expand Down
49 changes: 46 additions & 3 deletions python/mxnet/narray.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ def __add__(self, other):
else:
raise TypeError('type %s not supported' % str(type(other)))

def __iadd__(self, other):
    """In-place addition (``self += other``), writing the result into self.

    Parameters
    ----------
    other : NArray or int or float
        Value to add to this array.

    Returns
    -------
    NArray
        This array (self), updated in place.

    Raises
    ------
    TypeError
        If ``other`` is neither an NArray nor a number.
    """
    if isinstance(other, NArray):
        return NArray._plus(self, other, out=self)
    elif isinstance(other, (int, float)):
        # Scalars are normalized to float before the C-API call.
        return NArray._plus_scalar(self, float(other), out=self)
    else:
        raise TypeError('type %s not supported' % str(type(other)))

def __radd__(self, other):
    """Reflected addition (``other + self``); addition commutes, so reuse __add__."""
    return self + other

Expand All @@ -82,6 +90,14 @@ def __sub__(self, other):
else:
raise TypeError('type %s not supported' % str(type(other)))

def __isub__(self, other):
    """In-place subtraction (``self -= other``), writing the result into self.

    Parameters
    ----------
    other : NArray or int or float
        Value to subtract from this array.

    Returns
    -------
    NArray
        This array (self), updated in place.

    Raises
    ------
    TypeError
        If ``other`` is neither an NArray nor a number.
    """
    if isinstance(other, NArray):
        return NArray._minus(self, other, out=self)
    elif isinstance(other, (int, float)):
        # Scalars are normalized to float before the C-API call.
        return NArray._minus_scalar(self, float(other), out=self)
    else:
        raise TypeError('type %s not supported' % str(type(other)))

def __mul__(self, other):
if isinstance(other, NArray):
return NArray._mul(self, other)
Expand All @@ -90,6 +106,14 @@ def __mul__(self, other):
else:
raise TypeError('type %s not supported' % str(type(other)))

def __imul__(self, other):
    """In-place multiplication (``self *= other``), writing the result into self.

    Parameters
    ----------
    other : NArray or int or float
        Value to multiply this array by.

    Returns
    -------
    NArray
        This array (self), updated in place.

    Raises
    ------
    TypeError
        If ``other`` is neither an NArray nor a number.
    """
    if isinstance(other, NArray):
        return NArray._mul(self, other, out=self)
    elif isinstance(other, (int, float)):
        # Scalars are normalized to float before the C-API call.
        return NArray._mul_scalar(self, float(other), out=self)
    else:
        raise TypeError('type %s not supported' % str(type(other)))

def __rmul__(self, other):
    """Reflected multiplication (``other * self``); multiplication commutes, so reuse __mul__."""
    return self * other

Expand All @@ -102,7 +126,12 @@ def __div__(self, other):
raise TypeError('type %s not supported' % str(type(other)))

def __idiv__(self, other):
    """In-place division (``self /= other``), writing the result into self.

    Parameters
    ----------
    other : NArray or int or float
        Divisor.

    Returns
    -------
    NArray
        This array (self), updated in place.

    Raises
    ------
    TypeError
        If ``other`` is neither an NArray nor a number.
    """
    if isinstance(other, NArray):
        return NArray._div(self, other, out=self)
    elif isinstance(other, (int, float)):
        # Scalars are normalized to float before the C-API call.
        return NArray._div_scalar(self, float(other), out=self)
    else:
        raise TypeError('type %s not supported' % str(type(other)))

def __truediv__(self, other):
    # Python 3 "true division" delegates to the Python 2 style __div__;
    # both operators share the same NArray semantics. The explicit
    # __div__ call (rather than the / operator) avoids re-dispatching
    # back into __truediv__.
    return self.__div__(other)
Expand Down Expand Up @@ -130,6 +159,20 @@ def __setstate__(self, state):
state['handle'] = handle
self.__dict__.update(state)

def __setitem__(self, in_slice, value):
    """Set narray value via full-slice assignment (``array[:] = target``).

    Parameters
    ----------
    in_slice : slice
        Must be the empty slice ``[:]``; stepped slices are rejected.
        NOTE(review): a non-slice index (e.g. an int) would raise
        AttributeError on ``.step`` — confirm whether that should be
        rejected explicitly.
    value : NArray
        Source array whose contents are copied into self.

    Raises
    ------
    Exception
        If the slice has a step.
    TypeError
        If ``value`` is not an NArray.
    """
    if in_slice.step is not None:
        raise Exception("Set NArray should use empty index array[:] = target_array")
    if not isinstance(value, NArray):
        raise TypeError('type %s not supported' % str(type(value)))
    value.copyto(self)

def __getitem__(self, in_slice):
    """Get narray: only the empty slice ``[:]`` is supported and it
    returns the whole array (self) — no actual sub-slicing is done.

    Parameters
    ----------
    in_slice : slice
        Must be the empty slice ``[:]``; stepped slices are rejected.

    Returns
    -------
    NArray
        This array (self).

    Raises
    ------
    Exception
        If the slice has a step.
    """
    if in_slice.step is not None:
        raise Exception("Set NArray should use empty index array[:] += value")
    return self

def wait(self):
    """Block until the data of this NArray has been computed and is readable."""
    status = _LIB.MXNArrayWait(self.handle)
    check_call(status)
Expand Down Expand Up @@ -369,7 +412,7 @@ def binary_narray_function(lhs, rhs, out=None):
"""Internal binary function
"""
if out:
if isinstance(out, NArray):
if isinstance(out, NArray) == False:
raise TypeError('out must be NArray')
else:
if not accept_empty_mutate:
Expand All @@ -384,7 +427,7 @@ def binary_narray_function(lhs, rhs, out=None):
def unary_narray_function(src, out=None):
"""internal NArray function"""
if out:
if isinstance(out, NArray):
if isinstance(out, NArray) == False:
raise TypeError('out must be NArray')
else:
if not accept_empty_mutate:
Expand Down
14 changes: 0 additions & 14 deletions src/dag_engine/dag_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,6 @@

namespace mxnet {

void DAGEngine::Push(Fn exec_fun, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) {
auto f = [exec_fun](RunContext ctx, Callback on_complete) {
exec_fun(ctx);
on_complete();
};
PushAsync(f, exec_ctx, use_vars, mutate_vars);
}

DAGEngine::~DAGEngine() noexcept(false) {}

DAGEngine::DAGEngine() = default;

DAGEngine* DAGEngine::Get() {
/*!
* \brief Change specific engine to use.
Expand Down
72 changes: 72 additions & 0 deletions src/dag_engine/naive_engine.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright by Contributors
#include <mxnet/dag_engine.h>

namespace mxnet {
namespace engine {

// The Naive engine interface.
// Executes every pushed function synchronously on the calling thread, so no
// dependency tracking is needed; most of the operator-handle API is left
// unimplemented (LOG(FATAL)).
class NaiveEngine : public DAGEngine {
 public:
  // No dependency bookkeeping is done, so variables carry no state and the
  // null handle suffices.
  Variable NewVar() override {
    return nullptr;
  }

  // Pre-built operator handles are not supported by this engine.
  OprHandle NewOperator(AsyncFn fn,
                        std::vector<Variable> const& use_vars,
                        std::vector<Variable> const& mutate_vars) override {
    LOG(FATAL) << "Not implemented";
    return nullptr;
  }

  // Pre-built operator handles are not supported by this engine.
  void DeleteOperator(OprHandle op) override {
    LOG(FATAL) << "Not implemented";
  }

  // Pushing a pre-built operator handle is not supported by this engine.
  void Push(OprHandle op, Context exec_ctx) override {
    LOG(FATAL) << "Not implemented";
  }

  // Run exec_fun immediately on the calling thread. use_vars/mutate_vars are
  // ignored: synchronous execution makes dependency ordering trivial.
  void Push(Fn exec_fun, Context exec_ctx,
            std::vector<Variable> const& use_vars,
            std::vector<Variable> const& mutate_vars) override {
    if (exec_ctx.dev_mask == gpu::kDevMask) {
      // Point the run context at the engine's single GPU stream.
      // NOTE(review): this assignment happens even in the non-CUDA build,
      // though LOG(FATAL) below aborts before exec_fun runs.
      ctx_.stream = &stream_;
#if MXNET_USE_CUDA
      mshadow::SetDevice<gpu>(exec_ctx.dev_id);
      exec_fun(ctx_);
#else
      LOG(FATAL) << "GPU is not enabled";
#endif
    } else {
      // CPU path: ctx_.stream is left as-is — presumably unused on CPU;
      // TODO confirm against RunContext consumers.
      exec_fun(ctx_);
    }
  }

  // Asynchronous push is not supported; use the synchronous Push above.
  void PushAsync(AsyncFn exec_fun, Context exec_ctx,
                 std::vector<Variable> const& use_vars,
                 std::vector<Variable> const& mutate_vars) override {
    LOG(FATAL) << "Not implemented";
  }

  // Deletion is just a synchronous function that mutates the variable.
  void PushDelete(Fn delete_fun, Context exec_ctx, Variable var) override {
    this->Push(delete_fun, exec_ctx, {}, {var});
  }

  // No-op: every Push completes before returning, so nothing is pending.
  void WaitForVar(Variable var) override {
  }

  // No-op: every Push completes before returning, so nothing is pending.
  void WaitForAll() override {
  }

 private:
  // Run context handed to every executed function.
  RunContext ctx_;
  // Single GPU stream shared by all GPU executions.
  mshadow::Stream<gpu> stream_;
};

}  // namespace engine

// Singleton accessor: the naive engine is the engine implementation linked
// into this build (selected via the ENGINE variable in the Makefile).
DAGEngine* DAGEngine::Get() {
  static mxnet::engine::NaiveEngine engine;
  return &engine;
}
}  // namespace mxnet
10 changes: 10 additions & 0 deletions src/dag_engine/simple_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ void SimpleEngine::DeleteOperator(OprHandle op) {
Push(func, Context{}, {}, deps);
}

// Synchronous push: adapt exec_fun to the asynchronous interface by
// signalling completion as soon as the function returns, then delegate
// to PushAsync for dependency scheduling.
void SimpleEngine::Push(Fn exec_fun, Context exec_ctx,
                        std::vector<Variable> const& use_vars,
                        std::vector<Variable> const& mutate_vars) {
  auto wrapped = [exec_fun](RunContext ctx, Callback on_complete) {
    exec_fun(ctx);
    on_complete();
  };
  PushAsync(wrapped, exec_ctx, use_vars, mutate_vars);
}

void SimpleEngine::Push(OprHandle op, Context exec_ctx) {
auto&& simple_opr = SimpleOpr::CastFromBase(op);
auto&& opr_block = new OprBlock{};
Expand Down
4 changes: 3 additions & 1 deletion src/dag_engine/simple_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ class SimpleEngine final : public DAGEngine {
std::vector<Variable> const& mutate_vars) override;
void DeleteOperator(OprHandle op) override;
void Push(OprHandle op, Context exec_ctx) override;
using DAGEngine::Push;
void Push(Fn exec_fun, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) override;
void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<Variable> const& use_vars,
std::vector<Variable> const& mutate_vars) override;
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def CalAcc(out, label):
wd = 0.0004

def Update(grad, weight):
weight.numpy[:] -= lr * grad.numpy[:] / batch_size
weight[:] -= lr * grad / batch_size

block = list(zip(grad_narrays, arg_narrays))

Expand Down
3 changes: 2 additions & 1 deletion tests/python/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def CalAcc(out, label):
epoch = 9
lr = 0.1
wd = 0.0004

def Update(grad, weight):
weight.numpy[:] -= lr * grad.numpy[:] / batch_size
weight[:] -= lr * grad / batch_size

block = list(zip(grad_narrays, arg_narrays))

Expand Down