This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Closed
Commits
72 commits
3bc227b
Test input a graph.
zheng-da Mar 20, 2018
dfdd28e
Update foreach to execute the subgraph.
zheng-da Mar 21, 2018
42c735d
print inputs/outputs in foreach.
zheng-da Mar 21, 2018
a4d6a64
Remove print.
zheng-da Mar 23, 2018
037caa0
add test code for foreach.
zheng-da Mar 23, 2018
6e4b9fb
exec foreach outside the engine.
zheng-da Apr 5, 2018
42ac88a
Implements forward of foreach.
zheng-da Apr 5, 2018
3318797
Add support for variable numbers of inputs and outputs.
zheng-da Apr 6, 2018
7670ace
Add a python wrapper for foreach.
zheng-da Apr 6, 2018
d98f878
Fix the order of inputs.
zheng-da Apr 6, 2018
4d6779c
hide C version of foreach.
zheng-da Apr 7, 2018
43f8c1b
fix a bug temporarily.
zheng-da Apr 7, 2018
efb6ba0
add test with lstm.
zheng-da Apr 6, 2018
7593275
Test free variables.
zheng-da Apr 7, 2018
6d4b90b
change for the new interface of InputGraph attribute.
zheng-da Apr 9, 2018
bde96cd
Add attribute to the subgraph.
zheng-da Apr 11, 2018
8bb020e
Handle free variables.
zheng-da Apr 11, 2018
a4998f6
Get all input symbols of a subgraph.
zheng-da Apr 13, 2018
409bbe2
Fix shape, dtype and storage inference.
zheng-da Apr 13, 2018
9660340
reorganize the output of foreach.
zheng-da Apr 14, 2018
a0f8d52
Add a gluon RNN unroll with symbol foreach.
zheng-da Apr 14, 2018
f462cc7
remove unnecessary print.
zheng-da Apr 16, 2018
9328034
have imperative and symbolic foreach.
zheng-da Apr 16, 2018
62d767f
Fix an error after moving foreach.
zheng-da Apr 18, 2018
f8c4383
Fix imperative foreach
zheng-da Apr 18, 2018
15d3619
Fix a minor problem.
zheng-da Apr 24, 2018
191fec2
Use CachedOp to execute subgraph.
zheng-da Apr 30, 2018
65f98e9
update TODO.
zheng-da May 1, 2018
5e9ec40
make foreach op use FStatefulComputeEx.
zheng-da May 1, 2018
f3ce49c
Add backward.
zheng-da May 2, 2018
0f111ff
Fix bugs.
zheng-da May 4, 2018
7688b74
enable backward test in lstm.
zheng-da May 4, 2018
f1141a4
Fix a bug in foreach backward for free variables.
zheng-da May 7, 2018
8f8e51d
change for the new CachedOp.
zheng-da May 9, 2018
05ed08d
Detect the backward computation.
zheng-da May 9, 2018
610a9dc
Fix bugs in foreach.
zheng-da May 9, 2018
9b4ac3c
fix tests.
zheng-da May 10, 2018
a08de42
update tests.
zheng-da May 11, 2018
5c8a187
check state shape.
zheng-da May 12, 2018
e6b53bc
enable nested foreach.
zheng-da May 14, 2018
65c8515
remove print.
zheng-da May 16, 2018
7e05c98
fix a bug in test.
zheng-da May 17, 2018
fa2b9bd
handle infer storage type for backward.
zheng-da May 18, 2018
fcf4c34
address comments.
zheng-da May 18, 2018
54f6efa
address comments.
zheng-da May 18, 2018
a8f86f8
move some common functions out.
zheng-da May 18, 2018
0f894fa
address comments.
zheng-da May 18, 2018
f356460
fix lint.
zheng-da May 18, 2018
6bc3d56
Fix lint.
zheng-da May 18, 2018
88ae9bf
add doc.
zheng-da May 19, 2018
1f03f01
undo modification in imperative.h
zheng-da May 19, 2018
28ba842
add doc and remove example code.
zheng-da May 19, 2018
14ca454
fix lint.
zheng-da May 19, 2018
242455a
fix lint.
zheng-da May 19, 2018
3f5d207
Fix lint.
zheng-da May 19, 2018
cd7f94b
make nd.foreach and sym.foreach consistent.
zheng-da May 21, 2018
3936aff
fix compile error.
zheng-da May 21, 2018
fb23c90
Fix bugs in MKLDNN.
zheng-da May 16, 2018
0a27af1
address comments.
zheng-da May 21, 2018
76006dd
update.
zheng-da May 21, 2018
a42a5a0
check for loop only works for dense arrays.
zheng-da May 22, 2018
d2eb153
move control flow op out of nn/
zheng-da May 22, 2018
767857f
fix include.
zheng-da May 22, 2018
471e6d2
add a test in gluon.
zheng-da May 22, 2018
580f294
work for GPU.
zheng-da May 22, 2018
ae9340d
small fix.
zheng-da May 22, 2018
977f562
remove subgraph_name
zheng-da May 22, 2018
fab06e5
create loop state for reuse in the future.
zheng-da May 22, 2018
03e4992
Merge branch 'foreach' of https://github.com/zheng-da/incubator-mxnet…
zheng-da May 22, 2018
08fbd04
move code.
zheng-da May 22, 2018
4550f85
Revert "remove subgraph_name"
zheng-da May 23, 2018
d3c4f6f
cut graph.
zheng-da May 25, 2018
22 changes: 22 additions & 0 deletions include/mxnet/c_api.h
@@ -1056,6 +1056,28 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
*/
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **name);

/*!
* \brief Get the input symbols of the graph.
* \param sym The graph.
* \param inputs The input symbols of the graph.
 * \param input_size The number of input symbols returned.
*/
MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs,
int *input_size);

/*!
* \brief Cut a subgraph whose nodes are marked with a subgraph attribute.
* The input graph will be modified. A variable node will be created for each
* edge that connects to nodes outside the subgraph. The outside nodes that
* connect to the subgraph will be returned.
* \param sym The graph.
* \param inputs The nodes that connect to the subgraph.
* \param input_size The number of such nodes.
*/
MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs,
int *input_size);

/*!
* \brief Get the detailed information about atomic symbol.
* \param creator the AtomicSymbolCreator.
4 changes: 4 additions & 0 deletions include/mxnet/ndarray.h
@@ -693,6 +693,10 @@ class NDArray {
NDArray MKLDNNDataReshape(const TShape &shape) const;
#endif

const nnvm::NodeEntry &entry() const {
return entry_;
}

/*!
 * \brief Save list of ndarray into the Stream.
* \param fo The stream of output.
4 changes: 3 additions & 1 deletion include/mxnet/op_attr_types.h
@@ -64,8 +64,10 @@ enum OpReqType {
* \sa Resource
*/
struct OpContext {
/*! \brief whether there is a backward phase to compute gradients. */
bool need_grad;
/*! \brief whether it is training phase */
int is_train;
bool is_train;
/*! \brief RunContext related resources */
RunContext run_ctx;
/*! \brief the callback when the operation completes, used by asynchronous ops */
95 changes: 95 additions & 0 deletions python/mxnet/ndarray/contrib.py
@@ -21,6 +21,8 @@
import math
from ..context import current_context
from ..random import uniform
from ..base import _as_list
from . import ndarray
try:
from .gen_contrib import *
except ImportError:
@@ -96,3 +98,96 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled
# pylint: enable=line-too-long

def foreach(body, data, init_states):
"""Run a for loop with user-defined computation over NDArrays on dimension 0.

This operator simulates a for loop: body defines the computation for one
iteration of the loop and is run on each slice taken from dimension 0 of the
input NDArrays.

body takes two arguments as input and outputs a tuple of two elements,
as illustrated below:

out, states = body(data1, states)

data1 can be either an NDArray or a list of NDArrays. If data is an NDArray,
data1 is an NDArray. Otherwise, data1 is a list of NDArrays with the same
length as data. states is a list of NDArrays with the same length as
init_states. Similarly, out can be either an NDArray or a list of NDArrays,
which are stacked as the first output of foreach; the states from the last
iteration of body are the second output of foreach.

The computation done by this operator is equivalent to the pseudo code below
when the input data is an NDArray:

states = init_states
outs = []
for i in range(data.shape[0]):
s = data[i]
out, states = body(s, states)
outs.append(out)
outs = stack(*outs)


Parameters
----------
body : a Python function.
Define computation in an iteration.
data: an NDArray or a list of NDArrays.
The input data.
init_states: an NDArray or a list of NDArrays.
The initial values of the loop states.

Returns
-------
outputs: an NDArray or a list of NDArrays.
The output data concatenated from the output of all iterations.
states: a list of NDArrays.
The loop states in the last iteration.

Examples
--------
>>> step = lambda data, states: (data + states[0], [states[0] * 2])
>>> data = mx.nd.random.uniform(shape=(2, 10))
>>> states = [mx.nd.random.uniform(shape=(10))]
>>> outs, states = mx.nd.contrib.foreach(step, data, states)
"""

def check_input(inputs, in_type, msg):
is_NDArray_or_list = True
if isinstance(inputs, list):
for i in inputs:
if not isinstance(i, in_type):
is_NDArray_or_list = False
break
else:
is_NDArray_or_list = isinstance(inputs, in_type)
assert is_NDArray_or_list, msg

check_input(data, ndarray.NDArray, "data should be an NDArray or a list of NDArrays")
check_input(init_states, ndarray.NDArray,
"init_states should be an NDArray or a list of NDArrays")

not_data_list = isinstance(data, ndarray.NDArray)
not_state_list = isinstance(init_states, ndarray.NDArray)
num_iters = data.shape[0] if not_data_list else data[0].shape[0]
states = init_states
outputs = []
for i in range(num_iters):
if not_data_list:
eles = data[i]
else:
eles = [d[i] for d in data]
outs, states = body(eles, states)
outs = _as_list(outs)
outputs.append(outs)
# zip returns an iterator in Python 3; materialize it so it can be indexed.
outputs = list(zip(*outputs))
for j, out in enumerate(outputs):
outputs[j] = ndarray.op.stack(*out)

if not_data_list:
outputs = outputs[0]
return (outputs, states)
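The imperative loop above can be illustrated without MXNet. The sketch below is a hypothetical, plain-Python stand-in for the same semantics: iterate over dimension 0 of `data`, thread `states` through each call to `body`, and collect the per-iteration outputs (the list here plays the role of `stack(*outs)`). The names `foreach_sketch` and `step` are illustrative, not part of the API.

```python
def foreach_sketch(body, data, init_states):
    """Plain-Python illustration of the foreach semantics."""
    states = init_states
    outputs = []
    for row in data:  # slice along dimension 0
        out, states = body(row, states)
        outputs.append(out)
    return outputs, states  # outputs stands in for stack(*outs)

# A step mirroring the docstring example: out = data + states[0],
# and the next state doubles the current one.
step = lambda row, states: (row + states[0], [states[0] * 2])
outs, final_states = foreach_sketch(step, [1.0, 2.0], [10.0])
# outs == [11.0, 22.0], final_states == [40.0]
```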
201 changes: 201 additions & 0 deletions python/mxnet/symbol/contrib.py
@@ -19,13 +19,21 @@
# pylint: disable=wildcard-import, unused-wildcard-import
"""Contrib Symbol API of MXNet."""
import math
import ctypes
import re

from .random import uniform
from .symbol import Symbol
try:
from .gen_contrib import *
except ImportError:
pass

from . import symbol
from ..base import _LIB, c_array, check_call
from ..base import SymbolHandle, _as_list
from ..attribute import AttrScope

__all__ = ["rand_zipfian"]

def rand_zipfian(true_classes, num_sampled, range_max):
@@ -91,3 +99,196 @@ def rand_zipfian(true_classes, num_sampled, range_max):
expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled

def _get_graph_inputs(subg):
# Assume the graph has at most 1000 inputs; the C call writes the actual
# count back through num_handles.
num_handles = ctypes.c_int(1000)
handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)])
check_call(_LIB.MXSymbolGetInputSymbols(subg.handle, handles, ctypes.byref(num_handles)))

syms = []
for i in range(num_handles.value):
s = Symbol(handles[i])
syms.append(s)
return syms

def _cut_subgraph(subg):
num_handles = ctypes.c_int(1000)
handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)])
check_call(_LIB.MXSymbolCutSubgraph(subg.handle, handles, ctypes.byref(num_handles)))

syms = []
for i in range(num_handles.value):
s = Symbol(handles[i])
syms.append(s)
return syms
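Both helpers above use the caller-allocated out-array convention: the Python side passes a pre-sized handle array plus an in/out count, and the C side fills the array and writes back how many entries it produced. The sketch below illustrates that convention with a pure-Python stand-in for the C function (`fake_c_fill` is hypothetical; no MXNet library is called).

```python
import ctypes

CAP = 8  # stand-in for the 1000-slot buffer used above

def fake_c_fill(out_array, count_ptr, values):
    """Mimic a C API: fill up to the given capacity, report how many."""
    n = min(len(values), count_ptr.contents.value)
    for i in range(n):
        out_array[i] = values[i]
    count_ptr.contents.value = n  # write the produced count back
    return 0

buf = (ctypes.c_int * CAP)()
count = ctypes.c_int(CAP)
fake_c_fill(buf, ctypes.pointer(count), [10, 20, 30])
inputs = [buf[i] for i in range(count.value)]
# inputs == [10, 20, 30]
```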

def foreach(body, data, init_states, name="foreach"):
"""Run a for loop with user-defined computation over Symbols on dimension 0.

This operator simulates a for loop: body defines the computation for one
iteration of the loop and is run on each slice taken from dimension 0 of the
input symbols.

body takes two arguments as input and outputs a tuple of two elements,
as illustrated below:

out, states = body(data1, states)

data1 can be either a symbol or a list of symbols. If data is a symbol,
data1 is a symbol. Otherwise, data1 is a list of symbols with the same
length as data. states is a list of symbols with the same length as
init_states. Similarly, out can be either a symbol or a list of symbols,
which are stacked as the first output of foreach; the states from the last
iteration of body are the second output of foreach.

The computation done by this operator is equivalent to the pseudo code below
when the input data is an NDArray:

states = init_states
outs = []
for i in range(data.shape[0]):
s = data[i]
out, states = body(s, states)
outs.append(out)
outs = stack(*outs)


Parameters
----------
body : a Python function.
Define computation in an iteration.
data: a symbol or a list of symbols.
The input data.
init_states: a symbol or a list of symbols.
The initial values of the loop states.
name: string.
The name of the operator.

Returns
-------
outputs: a Symbol or a list of Symbols.
The output data concatenated from the output of all iterations.
states: a list of Symbols.
The loop states in the last iteration.

Examples
--------
>>> step = lambda data, states: (data + states[0], [states[0] * 2])
>>> data = mx.sym.var('data')
>>> states = [mx.sym.var('state')]
>>> outs, states = mx.sym.contrib.foreach(step, data, states)
"""

def check_data(inputs, in_type, msg):
is_NDArray_or_list = True
if isinstance(inputs, list):
for i in inputs:
if not isinstance(i, in_type):
is_NDArray_or_list = False
break
else:
is_NDArray_or_list = isinstance(inputs, in_type)
assert is_NDArray_or_list, msg

check_data(data, symbol.Symbol, "data should be a Symbol or a list of Symbols")
check_data(init_states, symbol.Symbol,
"init_states should be a Symbol or a list of Symbols")
not_state_list = isinstance(init_states, symbol.Symbol)

# TODO(zhengda) If the input python function references to the symbols outside
# the python function, we need to prune the computation graph constructed from
# the function. One way of doing it is to mark the nodes in the computation graph
# with AttrScope and prune the nodes without the special attribute.
with AttrScope(subgraph_name=name):
if isinstance(data, list):
in_eles = [symbol.var(sym.name) for sym in data]
else:
in_eles = symbol.var(data.name)
if isinstance(init_states, list):
states = [symbol.var(s.name) for s in init_states]
else:
states = symbol.var(init_states.name)
sym_out, sym_states = body(in_eles, states)

check_data(sym_out, symbol.Symbol,
"the output should be a Symbol or a list of Symbols")
check_data(sym_states, symbol.Symbol,
"the output states should be a Symbol or a list of Symbols")
if isinstance(sym_states, list):
assert isinstance(init_states, list) and len(sym_states) == len(init_states), \
"the number of output states (%d) should be the same as input states (%d)" \
% (len(sym_states), len(init_states))

if isinstance(sym_out, list):
flat_out = sym_out
else:
flat_out = [sym_out]
num_out_data = len(flat_out)
if isinstance(sym_states, list):
for s in sym_states:
# There is a problem if the outputs are the same as the inputs
# or the first output. By calling identity, we can make sure that
# all symbols will refer to different NDArrays.
flat_out.append(symbol.op.identity(s))
else:
flat_out.append(symbol.op.identity(sym_states))
g = symbol.Group(flat_out)

cut_syms = _cut_subgraph(g)
input_syms = _get_graph_inputs(g)

# Here we need to find out how the input symbols are ordered as well as
# where the loop states are located in the list of inputs.

# This dict contains the symbols of the subgraph.
input_syms = {sym.name:sym for sym in input_syms}
gin_names = input_syms.keys()
# This array contains the symbols for the inputs of foreach.
# They are ordered according to the inputs of the subgraph.
states_map = {sym.name:sym for sym in _as_list(init_states)}
state_names = states_map.keys()
data_syms = _as_list(data)
data_map = {sym.name:sym for sym in data_syms}
data_names = data_map.keys()

ordered_ins = []
in_state_locs = []
in_data_locs = []
for in_name in g.list_inputs():
assert in_name in gin_names, "The input variable %s can't be found in graph inputs: %s" \
% (in_name, str(gin_names))
if in_name in state_names:
ordered_ins.append(states_map[in_name])
in_state_locs.append(len(ordered_ins) - 1)
elif in_name in data_names:
ordered_ins.append(data_map[in_name])
in_data_locs.append(len(ordered_ins) - 1)
else:
# The remaining inputs are the ones cut from the original graph.
# The names of these variable nodes contain the index in cut_syms.
m = re.search(r'\d+$', in_name)
assert m is not None, "The cut variable %s doesn't end with an index" % in_name
idx = int(m.group())
assert idx < len(cut_syms)
ordered_ins.append(cut_syms[idx])

num_outputs = len(flat_out)
num_states = len(state_names)
ret = symbol._internal._foreach(g, *ordered_ins, num_outputs=num_outputs,
num_out_data=num_out_data, in_state_locs=in_state_locs,
in_data_locs=in_data_locs)
if num_outputs - num_states > 1:
outs = []
for i in range(num_outputs - num_states):
outs.append(ret[i])
else:
outs = ret[0]
states = []
for i in range(num_states):
states.append(ret[num_outputs - num_states + i])

if not_state_list:
# If there is only one input state, there should be only one output state.
assert len(states) == 1
states = states[0]

return (outs, states)
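The input-ordering loop above (the one building `ordered_ins`, `in_state_locs`, and `in_data_locs`) can be sketched in plain Python: each graph input name is matched first against the loop states, then against the data, and anything left over is treated as a variable cut from the outer graph whose name ends with its index into `cut_syms`. The function and argument names below are illustrative only.

```python
import re

def order_inputs(graph_input_names, state_names, data_names, cut_syms):
    """Classify subgraph inputs and record where states/data land."""
    ordered, state_locs, data_locs = [], [], []
    for name in graph_input_names:
        if name in state_names:
            ordered.append(('state', name))
            state_locs.append(len(ordered) - 1)
        elif name in data_names:
            ordered.append(('data', name))
            data_locs.append(len(ordered) - 1)
        else:
            # A variable cut from the outer graph; its name ends with
            # its index in cut_syms.
            m = re.search(r'\d+$', name)
            assert m, "cut-variable name should end with an index"
            ordered.append(('cut', cut_syms[int(m.group())]))
    return ordered, state_locs, data_locs

ordered, s_locs, d_locs = order_inputs(
    ['data', 'state', 'free_var0'], {'state'}, {'data'}, ['w'])
# ordered == [('data', 'data'), ('state', 'state'), ('cut', 'w')]
# s_locs == [1], d_locs == [0]
```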