diff --git a/ci/docker/docker-compose.yml b/ci/docker/docker-compose.yml
index c9e410efcccb..0a3f320f42b9 100644
--- a/ci/docker/docker-compose.yml
+++ b/ci/docker/docker-compose.yml
@@ -58,7 +58,7 @@ services:
       dockerfile: Dockerfile.build.centos7
       target: base
       args:
-        BASE_IMAGE: nvidia/cuda:10.2-cudnn7-devel-centos7
+        BASE_IMAGE: nvidia/cuda:10.2-cudnn8-devel-centos7
       cache_from:
         - ${DOCKER_CACHE_REGISTRY}/build.centos7_gpu_cu102:latest
   centos7_gpu_cu110:
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 8c429bcfd0ac..4ca1a94704cf 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -472,6 +472,9 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
             ctx = [context.current_context()]
         if isinstance(ctx, Context):
             ctx = [ctx]
+        if isinstance(self.init, initializer.RNNFused):
+            self.init.set_initializer(init if init else default_init)
+            init = default_init = self.init
         if init is None:
             init = default_init if self.init is None else self.init
         if not shape_is_known(self.shape):
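Note on the parameter.py hunk above: the new branch only fires for parameters whose init is an RNNFused instance, i.e. the fused rnn_param introduced in rnn_layer.py below; every other parameter keeps the existing init/default_init resolution. A minimal sketch of the intended effect, assuming this branch is installed (the layer sizes here are arbitrary):

    import mxnet as mx

    lstm = mx.gluon.rnn.LSTM(8, input_size=8)
    # rnn_param's init is an RNNFused instance, so Uniform(0.1) is not applied
    # to the flat vector directly: initialize() hands it to RNNFused via
    # set_initializer(), which uses it for the weight blocks while the bias
    # blocks keep their default 'zeros' initializers.
    lstm.initialize(mx.init.Uniform(0.1))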
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 59056de6ce7b..8a854ca3f0c3 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,7 +23,7 @@
 __all__ = ['RNN', 'LSTM', 'GRU']
 
-from ... import np, npx, context
+from ... import np, npx, context, initializer
 from .. import HybridBlock, tensor_types
 from ..parameter import Parameter
 from ...util import use_np
@@ -50,11 +50,6 @@ def __init__(self, hidden_size, num_layers, layout,
         self._dropout = dropout
         self._dir = 2 if bidirectional else 1
         self._input_size = input_size
-        self._i2h_weight_initializer = i2h_weight_initializer
-        self._h2h_weight_initializer = h2h_weight_initializer
-        self._i2h_bias_initializer = i2h_bias_initializer
-        self._h2h_bias_initializer = h2h_bias_initializer
-        self._h2r_weight_initializer = h2r_weight_initializer
         self._lstm_state_clip_min = lstm_state_clip_min
         self._lstm_state_clip_max = lstm_state_clip_max
         self._lstm_state_clip_nan = lstm_state_clip_nan
@@ -64,48 +59,17 @@ def __init__(self, hidden_size, num_layers, layout,
         self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
 
-        ng, ni, nh = self._gates, input_size, hidden_size
-        if not projection_size:
-            for i in range(num_layers):
-                for j in ['l', 'r'][:self._dir]:
-                    self._register_param('{}{}_i2h_weight'.format(j, i),
-                                         shape=(ng*nh, ni),
-                                         init=i2h_weight_initializer, dtype=dtype)
-                    self._register_param('{}{}_h2h_weight'.format(j, i),
-                                         shape=(ng*nh, nh),
-                                         init=h2h_weight_initializer, dtype=dtype)
-                    self._register_param('{}{}_i2h_bias'.format(j, i),
-                                         shape=(ng*nh,),
-                                         init=i2h_bias_initializer, dtype=dtype)
-                    self._register_param('{}{}_h2h_bias'.format(j, i),
-                                         shape=(ng*nh,),
-                                         init=h2h_bias_initializer, dtype=dtype)
-                ni = nh * self._dir
-        else:
-            ps = self._projection_size
-            for i in range(num_layers):
-                for j in ['l', 'r'][:self._dir]:
-                    self._register_param('{}{}_i2h_weight'.format(j, i),
-                                         shape=(ng*nh, ni),
-                                         init=i2h_weight_initializer, dtype=dtype)
-                    self._register_param('{}{}_h2h_weight'.format(j, i),
-                                         shape=(ng*nh, ps),
-                                         init=h2h_weight_initializer, dtype=dtype)
-                    self._register_param('{}{}_i2h_bias'.format(j, i),
-                                         shape=(ng*nh,),
-                                         init=i2h_bias_initializer, dtype=dtype)
-                    self._register_param('{}{}_h2h_bias'.format(j, i),
-                                         shape=(ng*nh,),
-                                         init=h2h_bias_initializer, dtype=dtype)
-                    self._register_param('{}{}_h2r_weight'.format(j, i),
-                                         shape=(ps, nh),
-                                         init=h2r_weight_initializer, dtype=dtype)
-                ni = ps * self._dir
-
-    def _register_param(self, name, shape, init, dtype):
-        p = Parameter(name, shape=shape, init=init, allow_deferred_init=True, dtype=dtype)
-        setattr(self, name, p)
-        return p
+        param_initializer = initializer.RNNFused(
+            mode, num_layers, hidden_size,
+            bidirectional, projection_size,
+            i2h_weight_initializer=i2h_weight_initializer,
+            h2h_weight_initializer=h2h_weight_initializer,
+            i2h_bias_initializer=i2h_bias_initializer,
+            h2h_bias_initializer=h2h_bias_initializer,
+            h2r_weight_initializer=h2r_weight_initializer)
+
+        self.rnn_param = Parameter('rnn_param', shape=(-1,), init=param_initializer,
+                                   allow_deferred_init=True, dtype=dtype)
 
     def __repr__(self):
         s = '{name}({mapping}, {_layout}'
@@ -116,8 +80,7 @@ def __repr__(self):
         if self._dir == 2:
             s += ', bidirectional'
         s += ')'
-        shape = self.l0_i2h_weight.shape
-        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates)
+        mapping = '{0} -> {1}'.format(self._input_size if self._input_size else None, self._hidden_size)
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **self.__dict__)
@@ -196,37 +159,26 @@ def forward(self, inputs, states, sequence_length=None):
     def infer_shape(self, inputs, *args):
         assert inputs.ndim == 3, \
             "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]"
-        if not self._projection_size:
-            step = self._hidden_size
-        else:
-            step = self._projection_size
-        ni = inputs.shape[2]
-        for i in range(self._num_layers):
-            for j in ['l', 'r'][:self._dir]:
-                name = '{}{}_i2h_weight'.format(j, i)
-                getattr(self, name).shape = (self._gates*self._hidden_size, ni)
-            ni = step * self._dir
+        self._input_size = inputs.shape[2]
+        ng, ni, nh = self._gates, inputs.shape[2], self._hidden_size
+
+        size = nh * self._dir * ng
+        size1 = (ni + nh + 2) * size  # first layer size
+        size2 = (nh * self._dir + nh + 2) * size  # second layer size
+        if self._projection_size:
+            size1 = (ni + self._projection_size + 2) * size  # first layer size
+            size2 = (self._projection_size * self._dir + \
+                self._projection_size + 2) * size  # second layer size
+        param_size = size1 + (self._num_layers - 1) * size2
+        if self._projection_size:
+            param_size += self._projection_size * nh * self._num_layers * self._dir
+        self.rnn_param.shape = (param_size, )
 
     def _forward_kernel(self, inputs, states, sequence_length):
         """ forward using CUDNN or CPU kenrel"""
         ctx = inputs.ctx
         if self._layout == 'NTC':
             inputs = np.swapaxes(inputs, 0, 1)
-        if self._projection_size is None:
-            params = (getattr(self, '{}{}_{}_{}'.format(d, l, g, t)).data(ctx).reshape(-1)
-                      for t in ['weight', 'bias']
-                      for l in range(self._num_layers)
-                      for d in ['l', 'r'][:self._dir]
-                      for g in ['i2h', 'h2h'])
-        else:
-            params = (getattr(self, '{}{}_{}_{}'.format(d, l, g, t)).data(ctx).reshape(-1)
-                      for t in ['weight', 'bias']
-                      for l in range(self._num_layers)
-                      for d in ['l', 'r'][:self._dir]
-                      for g in ['i2h', 'h2h', 'h2r']
-                      if g != 'h2r' or t != 'bias')
-
-        params = np.concatenate(params, axis=0)
 
         if self._use_sequence_length:
             rnn_args = states + [sequence_length]
@@ -238,7 +190,8 @@ def _forward_kernel(self, inputs, states, sequence_length):
             new_args = args.as_in_ctx(ctx)
             rnn_args_ctx.append(new_args)
 
-        rnn = npx.rnn(inputs, params, *rnn_args_ctx, use_sequence_length=self._use_sequence_length,
+        rnn = npx.rnn(inputs, self.rnn_param.data().as_in_ctx(ctx), *rnn_args_ctx,
+                      use_sequence_length=self._use_sequence_length,
                       state_size=self._hidden_size, projection_size=self._projection_size,
                       num_layers=self._num_layers, bidirectional=self._dir == 2,
                       p=self._dropout, state_outputs=True, mode=self._mode,
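The closed form used by the new infer_shape above can be checked by hand. A quick sketch of the arithmetic for a single-layer, unidirectional LSTM with input_size=7 and hidden_size=10, which is the configuration asserted in test_layer_fill_shape at the end of this diff:

    # ng: gates per cell (4 for LSTM); dirs: 1 because unidirectional; no projection
    ng, ni, nh, dirs, layers = 4, 7, 10, 1, 1
    size = nh * dirs * ng                   # 40 rows shared by every block
    size1 = (ni + nh + 2) * size            # layer 0: i2h + h2h weights plus 2 biases
    size2 = (nh * dirs + nh + 2) * size     # each additional layer
    param_size = size1 + (layers - 1) * size2
    assert param_size == 760                # matches rnn_param.shape[0] in the test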
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index df641cf1ace5..f9ddd3b97ca5 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -504,3 +504,79 @@ def _check_block_input_np_ndarrays(inputs):
         for i in inputs:
             _check_block_input_np_ndarrays(i)
 # pylint: enable=no-else-raise
+
+
+# pylint: disable=too-many-nested-blocks
+def split_rnn_params(param, mode, num_layers, input_size, hidden_size, bidirectional=False, projection_size=None):
+    """Split the fused RNN layer parameter into per-layer weight and bias arrays.
+
+    Parameters
+    ----------
+    param : ndarray
+        The fused parameter of the RNN layer.
+    mode : str
+        Mode of the RNN. Supported modes: rnn_relu, rnn_tanh, lstm, gru.
+    num_layers : int
+        Number of recurrent layers.
+    input_size : int
+        The number of expected features in the input x, i.e. the size of the
+        last dimension of the input tensor.
+    hidden_size : int
+        The number of features in the hidden state h.
+    bidirectional : bool, default False
+        If `True`, the parameter belongs to a bidirectional RNN.
+    projection_size : int, default None
+        The number of features after projection.
+    """
+    gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
+    dir = 2 if bidirectional else 1
+    param_dict = {}
+    begin = 0
+    if not projection_size:
+        for p in ['weight', 'bias']:
+            for l in range(num_layers):
+                for d in ['l', 'r'][:dir]:
+                    for g in ['i2h', 'h2h']:
+                        ni = input_size
+                        if l != 0:
+                            ni = hidden_size * dir
+                        if g == 'h2h':
+                            ni = hidden_size
+                        shape0 = gates * hidden_size
+                        if p == 'weight':
+                            cur_len = shape0 * ni
+                            param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
+                                param[begin:begin+cur_len].reshape(shape0, ni)
+                        else:
+                            cur_len = shape0
+                            param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
+                                param[begin:begin+cur_len].reshape(shape0,)
+                        begin += cur_len
+    else:
+        for p in ['weight', 'bias']:
+            for l in range(num_layers):
+                for d in ['l', 'r'][:dir]:
+                    for g in ['i2h', 'h2h', 'h2r']:
+                        if g != 'h2r' or p != 'bias':
+                            if g == 'h2r':
+                                cur_len = projection_size * hidden_size
+                                param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
+                                    param[begin:begin+cur_len]. \
+                                    reshape(projection_size, hidden_size)
+                            else:
+                                ni = input_size
+                                if l != 0:
+                                    ni = projection_size * dir
+                                if g == 'h2h':
+                                    ni = projection_size
+                                shape0 = gates * hidden_size
+                                if p == 'weight':
+                                    cur_len = shape0 * ni
+                                    param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
+                                        param[begin:begin+cur_len].reshape(shape0, ni)
+                                else:
+                                    cur_len = shape0
+                                    param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
+                                        param[begin:begin+cur_len].reshape(shape0,)
+                            begin += cur_len
+    return param_dict
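A hypothetical usage sketch for split_rnn_params, reusing the 1-layer unidirectional LSTM sizes from above (input_size=7, hidden_size=10): splitting a 760-element fused vector recovers the old per-block parameter names and shapes.

    import mxnet as mx
    from mxnet.gluon.utils import split_rnn_params

    fused = mx.np.arange(760)                  # stand-in for rnn_param.data()
    parts = split_rnn_params(fused, 'lstm', num_layers=1,
                             input_size=7, hidden_size=10)
    for name, block in parts.items():
        print(name, block.shape)
    # l0_i2h_weight (40, 7), l0_h2h_weight (40, 10),
    # l0_i2h_bias (40,), l0_h2h_bias (40,)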
diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py
index 60a1d3d667cc..987b357f5756 100644
--- a/python/mxnet/initializer.py
+++ b/python/mxnet/initializer.py
@@ -711,3 +711,122 @@ def _init_weight(self, name, arr):
         # gate of the 4 LSTM gates, we modify the according values.
         num_hidden = int(arr.shape[0] / 4)
         arr[num_hidden:2*num_hidden] = self.forget_bias
+
+
+@register
+class RNNFused(Initializer):
+    """Initialize the fused RNN parameter block by block. By default the bias
+    blocks are initialized to 0.0 and the weight blocks are filled with values
+    uniformly sampled from a given range.
+
+    Parameters
+    ----------
+    mode : {'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required
+        the type of RNN to compute
+    num_layers : int (non-negative), required
+        number of stacked layers
+    state_size : int (non-negative), required
+        size of the state for each layer
+    bidirectional : boolean, optional, default=False
+        whether to use bidirectional recurrent layers
+    projection_size : int or None, optional, default=None
+        size of the projection layer (LSTMP only)
+    i2h_weight_initializer, h2h_weight_initializer : str or Initializer, optional
+        Initializers for the input-to-hidden and hidden-to-hidden weight blocks.
+    i2h_bias_initializer, h2h_bias_initializer : str or Initializer, optional
+        Initializers for the input-to-hidden and hidden-to-hidden bias blocks.
+    h2r_weight_initializer : str or Initializer, optional
+        Initializer for the hidden-to-projection weight blocks (LSTMP only).
+    """
+    def __init__(self, mode, num_layers, state_size, bidirectional=False,
+                 projection_size=None, i2h_weight_initializer=None,
+                 h2h_weight_initializer=None, i2h_bias_initializer=None,
+                 h2h_bias_initializer=None, h2r_weight_initializer=None):
+        super(RNNFused, self).__init__(mode=mode, num_layers=num_layers,
+                                       state_size=state_size,
+                                       bidirectional=bidirectional,
+                                       projection_size=projection_size,
+                                       i2h_weight_initializer=i2h_weight_initializer,
+                                       h2h_weight_initializer=h2h_weight_initializer,
+                                       i2h_bias_initializer=i2h_bias_initializer,
+                                       h2h_bias_initializer=h2h_bias_initializer,
+                                       h2r_weight_initializer=h2r_weight_initializer)
+        self.gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
+        self.num_layers = num_layers
+        self.num_hidden = state_size
+        self.dir = 2 if bidirectional else 1
+        self.projection_size = projection_size
+        self._i2h_weight_initializer = i2h_weight_initializer
+        self._h2h_weight_initializer = h2h_weight_initializer
+        self._i2h_bias_initializer = i2h_bias_initializer
+        self._h2h_bias_initializer = h2h_bias_initializer
+        self._h2r_weight_initializer = h2r_weight_initializer
+
+    # pylint: disable=too-many-nested-blocks
+    def _init_weight(self, name, arr):
+        arr_len = arr.shape[0]
+        size = self.num_hidden * self.dir * self.gates
+        if not self.projection_size:
+            # second layer size
+            size2 = (self.num_hidden * self.dir + self.num_hidden + 2) * size
+            input_size = (arr_len - (self.num_layers - 1) * size2) // \
+                size - 2 - self.num_hidden
+        else:
+            # second layer size
+            size2 = (self.projection_size * self.dir + self.projection_size + 2) * size
+            size_projection = self.projection_size * self.num_hidden * self.num_layers * self.dir
+            input_size = (arr_len - size_projection - (self.num_layers - 1) * size2) // \
+                size - 2 - self.projection_size
+        begin = 0
+        if not self.projection_size:
+            for param in ['weight', 'bias']:
+                for layer_num in range(self.num_layers):
+                    for _ in range(self.dir):
+                        for connect in ['i2h', 'h2h']:
+                            num_inputs = input_size
+                            if layer_num != 0:
+                                num_inputs = self.num_hidden * self.dir
+                            if connect == 'h2h':
+                                num_inputs = self.num_hidden
+                            shape0 = self.gates * self.num_hidden
+                            if param == 'weight':
+                                cur_len = shape0 * num_inputs
+                            else:
+                                cur_len = shape0
+                            self._init_util(param, connect, arr[begin:begin+cur_len])
+                            begin += cur_len
+        else:
+            for param in ['weight', 'bias']:
+                for layer_num in range(self.num_layers):
+                    for _ in range(self.dir):
+                        for connect in ['i2h', 'h2h', 'h2r']:
+                            if connect != 'h2r' or param != 'bias':
+                                if connect == 'h2r':
+                                    cur_len = self.projection_size * self.num_hidden
+                                else:
+                                    num_inputs = input_size
+                                    if layer_num != 0:
+                                        num_inputs = self.projection_size * self.dir
+                                    if connect == 'h2h':
+                                        num_inputs = self.projection_size
+                                    shape0 = self.gates * self.num_hidden
+                                    if param == 'weight':
+                                        cur_len = shape0 * num_inputs
+                                    else:
+                                        cur_len = shape0
+                                self._init_util(param, connect,
+                                                arr[begin:begin+cur_len])
+                                begin += cur_len
+
+    def _init_util(self, param, connect, arr):
+        name = "_{}_{}_initializer".format(connect, param)
+        init = getattr(self, name)
+        create(init)(InitDesc(name, {'__init__': init}), arr)
+
+    def set_initializer(self, init):
+        self._i2h_weight_initializer = \
+            init if not self._i2h_weight_initializer else 'uniform'
+        self._h2h_weight_initializer = \
+            init if not self._h2h_weight_initializer else 'uniform'
+        self._i2h_bias_initializer = \
+            init if not self._i2h_bias_initializer else 'zero'
+        self._h2h_bias_initializer = \
+            init if not self._h2h_bias_initializer else 'zero'
+        self._h2r_weight_initializer = \
+            init if not self._h2r_weight_initializer else 'uniform'
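RNNFused._init_weight only receives the flattened array, so it first recovers the input size from the array length before walking the blocks. This is the inverse of the infer_shape arithmetic; a quick check on the same 760-element example:

    # 1-layer unidirectional LSTM, hidden_size=10, fused length 760 (see above)
    gates, nh, dirs, layers, arr_len = 4, 10, 1, 1, 760
    size = nh * dirs * gates                          # 40
    size2 = (nh * dirs + nh + 2) * size               # size of each layer above layer 0
    input_size = (arr_len - (layers - 1) * size2) // size - 2 - nh
    assert input_size == 7                            # consistent with infer_shape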
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 134eab397640..94190abdb21e 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -97,9 +97,15 @@ def test_lstmp():
     lstm_cell.initialize(ctx=ctx)
     layer_params = lstm_layer.collect_params()
     cell_params = lstm_cell.collect_params()
+    params = (weights['{}_{}'.format(g, t)].reshape(-1)
+              for t in ['weight', 'bias']
+              for g in ['i2h', 'h2h', 'h2r']
+              if g != 'h2r' or t != 'bias')
+
+    net_params_concat = mx.np.concatenate(params)
+    layer_params['rnn_param'].set_data(net_params_concat)
     for k, v in weights.items():
-        layer_params['l0_' + k].set_data(v.copy())
-        cell_params[k].set_data(v.copy())
+        cell_params[k].set_data(v)
     with autograd.record():
         layer_output = lstm_layer(lstm_input.copy())
         cell_output = lstm_cell.unroll(seq_len, lstm_input.copy(), layout='TNC',
@@ -108,8 +114,10 @@ def test_lstmp():
     assert_almost_equal(layer_output, cell_output, rtol=rtol, atol=atol)
     layer_output.backward()
     cell_output.backward()
+    layer_params_split = split_rnn_params(layer_params['rnn_param'].grad(),\
+        'lstm', 1, input_size, hidden_size, False, projection_size=projection_size)
     for k, _ in weights.items():
-        layer_grad = layer_params['l0_' + k].grad()
+        layer_grad = layer_params_split['l0_' + k]
         cell_grad = cell_params[k].grad()
         print('checking gradient for {}'.format('lstm0_l0_' + k))
         assert_almost_equal(layer_grad, cell_grad, rtol=rtol, atol=atol)
@@ -196,6 +204,39 @@ def forward(self, inpt):
         weights['{}0_h2h_bias'.format(
             d)] = mx.np.random.uniform(size=(size * 4,))
 
+    if proj_size:
+        params = (weights['{}0_{}_{}'.format(d, g, t)].reshape(-1)
+                  for t in ['weight', 'bias']
+                  for d in ['l', 'r']
+                  for g in ['i2h', 'h2h', 'h2r']
+                  if g != 'h2r' or t != 'bias')
+    else:
+        params = (weights['{}0_{}_{}'.format(d, g, t)].reshape(-1)
+                  for t in ['weight', 'bias']
+                  for d in ['l', 'r']
+                  for g in ['i2h', 'h2h'])
+
+    net_params_concat = mx.np.concatenate(params)
+    if proj_size:
+        params_left = (weights['l0_{}_{}'.format(g, t)].reshape(-1)
+                       for t in ['weight', 'bias']
+                       for g in ['i2h', 'h2h', 'h2r']
+                       if g != 'h2r' or t != 'bias')
+    else:
+        params_left = (weights['l0_{}_{}'.format(g, t)].reshape(-1)
+                       for t in ['weight', 'bias']
+                       for g in ['i2h', 'h2h'])
+    if proj_size:
+        params_right = (weights['r0_{}_{}'.format(g, t)].reshape(-1)
+                        for t in ['weight', 'bias']
+                        for g in ['i2h', 'h2h', 'h2r']
+                        if g != 'h2r' or t != 'bias')
+    else:
+        params_right = (weights['r0_{}_{}'.format(g, t)].reshape(-1)
+                        for t in ['weight', 'bias']
+                        for g in ['i2h', 'h2h'])
+    net_ref_left_params = mx.np.concatenate(params_left)
+    net_ref_right_params = mx.np.concatenate(params_right)
     net = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=True)
     ref_net = RefBiLSTM(size, proj_size)
@@ -203,10 +244,9 @@ def forward(self, inpt):
     ref_net.initialize()
     net_params = net.collect_params()
     ref_net_params = ref_net.collect_params()
-    for k in weights:
-        net_params[k].set_data(weights[k])
-        ref_net_params[k.replace('l0', '_lstm_fwd.l0').replace(
-            'r0', '_lstm_bwd.l0')].set_data(weights[k])
+    net_params['rnn_param'].set_data(net_params_concat)
+    ref_net_params['_lstm_fwd.rnn_param'].set_data(net_ref_left_params)
+    ref_net_params['_lstm_bwd.rnn_param'].set_data(net_ref_right_params)
 
     data = mx.np.random.uniform(size=(11, 10, in_size))
     mx.test_utils.assert_allclose(net(data), ref_net(data), rtol=1e-6)
@@ -214,12 +254,7 @@ def forward(self, inpt):
 
 def check_layer_bidirectional_varseqlen(size, in_size):
-    weights = {}
-    for d in ['l', 'r']:
-        weights['{}0_i2h_weight'.format(d)] = mx.np.random.uniform(size=(size*4, in_size))
-        weights['{}0_h2h_weight'.format(d)] = mx.np.random.uniform(size=(size*4, size))
-        weights['{}0_i2h_bias'.format(d)] = mx.np.random.uniform(size=(size*4,))
-        weights['{}0_h2h_bias'.format(d)] = mx.np.random.uniform(size=(size*4,))
+    weight = mx.np.random.uniform(size=(784,))
 
     net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True)
     ref_net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=False)
@@ -227,9 +262,8 @@ def check_layer_bidirectional_varseqlen(size, in_size):
     ref_net.initialize()
     net_params = net.collect_params()
     ref_net_params = ref_net.collect_params()
-    for k in weights:
-        net_params[k].set_data(weights[k])
-        ref_net_params[k].set_data(weights[k])
+    net_params['rnn_param'].set_data(weight)
+    ref_net_params['rnn_param'].set_data(weight)
 
     batch_size = 10
     num_timesteps = 11
@@ -269,11 +303,10 @@ def check_layer_bidirectional_varseqlen(size, in_size):
     ref_net_params = ref_net.collect_params()
 
-    for k in weights:
-        net_grad = net_params[k].grad()
-        ref_net_grad = ref_net_params[k].grad()
-        assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(),
-                            rtol=1e-2, atol=1e-6)
+    net_grad = net_params['rnn_param'].grad()
+    ref_net_grad = ref_net_params['rnn_param'].grad()
+    assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(),
+                        rtol=1e-2, atol=1e-6)
 
 
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index d34519c332cc..db3636499c5c 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1362,7 +1362,7 @@ def forward(self, x):
                 x = self.encoders[i](x)
             return x
     net = Network()
-    net.initialize(mx.init.Xavier(), ctx=mx.cpu())
+    net.initialize(mx.init.Uniform(), ctx=mx.cpu())
     net.hybridize()
     x = onp.random.rand(32, 10, 10)
     x = mx.np.array(x).as_in_context(mx.cpu())
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 6af620969ff8..0336baa8d14e 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -21,6 +21,7 @@
 import copy
 from functools import partial
 from numpy.testing import assert_allclose
+from mxnet.gluon.utils import split_rnn_params
 import pytest
 from mxnet.test_utils import almost_equal, assert_almost_equal, default_context
 from common import assert_raises_cudnn_not_satisfied, retry
@@ -99,10 +100,15 @@ def test_lstmp():
     fused_layer_params = fused_layer.collect_params()
     stack_layer_params = stack_layer.collect_params()
 
-    for name, value in fused_layer_params.items():
-        w = mx.np.random.uniform(size=value.shape)
-        value.set_data(w.copy())
-        stack_layer_params[name[1:].replace('_', '.', 1)].set_data(w.copy())
+    fused_weight_shape = fused_layer_params['rnn_param'].shape
+    print(fused_weight_shape)
+    w = mx.np.zeros(shape=fused_weight_shape)
+    fused_layer_params['rnn_param'].set_data(w)
+    fused_layer_params_split = split_rnn_params(w, 'lstm', num_layers, input_size,\
+        hidden_size, False, projection_size=projection_size)
+
+    for name, value in fused_layer_params_split.items():
+        stack_layer_params[name[1:].replace('_', '.', 1)].set_data(value)
 
     fused_output, fused_states = fused_layer(lstm_input.copy(), fused_begin_state)
     stack_output, stack_states = stack_layer.unroll(seq_len, lstm_input.copy(), begin_state=stack_begin_state,
@@ -135,11 +141,15 @@ def test_lstmp():
     fused_layer_params = fused_layer.collect_params()
     stack_layer_params = stack_layer.collect_params()
 
-    for name, value in fused_layer_params.items():
-        w = mx.np.random.uniform(size=value.shape)
-        value.set_data(w.copy())
+    fused_weight_shape = fused_layer_params['rnn_param'].shape
+    w = mx.np.zeros(shape=fused_weight_shape)
+    fused_layer_params['rnn_param'].set_data(w)
+    fused_layer_params_split = split_rnn_params(w, 'lstm', num_layers, input_size,\
+        hidden_size, True, projection_size=projection_size)
+
+    for name, value in fused_layer_params_split.items():
         cur = name.split("_")[0]
-        stack_layer_params["{}.{}_cell.{}".format(cur[1:], name[0], name[len(cur)+1:])].set_data(w.copy())
+        stack_layer_params["{}.{}_cell.{}".format(cur[1:], name[0], name[len(cur)+1:])].set_data(value)
 
     fused_output, fused_states = fused_layer(lstm_input.copy(), fused_begin_state)
     stack_output, stack_states = stack_layer.unroll(seq_len, lstm_input.copy(), begin_state=stack_begin_state,
@@ -346,15 +356,28 @@ def forward(self, inpt):
         weights['{}0_i2h_bias'.format(d)] = mx.np.random.uniform(size=(size*4,))
         weights['{}0_h2h_bias'.format(d)] = mx.np.random.uniform(size=(size*4,))
 
+    params = (weights['{}0_{}_{}'.format(d, g, t)].reshape(-1)
+              for t in ['weight', 'bias']
+              for d in ['l', 'r']
+              for g in ['i2h', 'h2h'])
+    net_params_concat = mx.np.concatenate(params)
+    params_left = (weights['l0_{}_{}'.format(g, t)].reshape(-1)
+                   for t in ['weight', 'bias']
+                   for g in ['i2h', 'h2h'])
+    params_right = (weights['r0_{}_{}'.format(g, t)].reshape(-1)
+                    for t in ['weight', 'bias']
+                    for g in ['i2h', 'h2h'])
+    net_ref_left_params = mx.np.concatenate(params_left)
+    net_ref_right_params = mx.np.concatenate(params_right)
     net = gluon.rnn.LSTM(size, bidirectional=True)
     ref_net = RefBiLSTM(size)
     net.initialize()
     ref_net.initialize()
     net_params = net.collect_params()
     ref_net_params = ref_net.collect_params()
-    for k in weights:
-        net_params[k].set_data(weights[k])
-        ref_net_params[k.replace('l0', '_lstm_fwd.l0').replace('r0', '_lstm_bwd.l0')].set_data(weights[k])
+    net_params['rnn_param'].set_data(net_params_concat)
+    ref_net_params['_lstm_fwd.rnn_param'].set_data(net_ref_left_params)
+    ref_net_params['_lstm_bwd.rnn_param'].set_data(net_ref_right_params)
 
     data = mx.np.random.uniform(size=(11, 10, in_size))
     assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy(), rtol=1e-04, atol=1e-02)
@@ -656,7 +679,7 @@ def test_rnn_layers_fp16():
     run_rnn_layers('float16', 'float32', mx.gpu())
 
 
-def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=False, rtol=1e-2, atol=1e-4):
+def check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, input_size, hidden_size, bidirectional=False, rtol=1e-2, atol=1e-4):
     x = mx.np.random.normal(size=(1, 5, input_size))
     fused_begin_state = fused_layer.begin_state(1)
     stack_states = stack_layer.begin_state(batch_size=1)
@@ -666,16 +689,25 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_siz
     stack_layer.initialize()
     stack_layer_params = stack_layer.collect_params()
 
-    for name, value in fused_layer_params.items():
-        if 'weight' in name:
-            w = mx.np.zeros(shape=value.shape)
-        else:
-            w = mx.np.random.normal(size=value.shape)
-        value.set_data(w.copy())
+    fused_weight_shape = fused_layer_params['rnn_param'].shape
+    w = mx.np.zeros(shape=fused_weight_shape)
+    fused_layer_params_split = split_rnn_params(w, mode, num_layers, input_size, hidden_size, bidirectional)
+    for name, value in fused_layer_params_split.items():
+        if 'bias' in name:
+            fused_layer_params_split[name] = mx.np.random.normal(size=value.shape)
+    _dir = 2 if bidirectional else 1
+    params = (fused_layer_params_split['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+              for t in ['weight', 'bias']
+              for l in range(num_layers)
+              for d in ['l', 'r'][:_dir]
+              for g in ['i2h', 'h2h'])
+    fused_params = mx.np.concatenate(params)
+    fused_layer_params['rnn_param'].set_data(fused_params)
+    for name, value in fused_layer_params_split.items():
         cur = name.split('_')[0]
         num = cur[1:]
         stack_name = ('{}.{}_cell.'.format(num, name[0]) if bidirectional else num + '.'
                      ) + name[len(cur)+1:]
-        stack_layer_params[stack_name].set_data(w.copy())
+        stack_layer_params[stack_name].set_data(value)
 
     fx = x.copy()
     sx = x.copy()
@@ -686,7 +718,8 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_siz
     l = loss(fused_out, y).mean()
     l.backward()
     mx.npx.waitall()
-    fused_grads = dict([(name, p.grad()) for name, p in fused_layer.collect_params().items()])
+    fused_grads = split_rnn_params(fused_layer.collect_params()['rnn_param'].data().grad,\
+        mode, num_layers, input_size, hidden_size, bidirectional)
     fused_input_grad = fx.grad.asnumpy()
 
     sx.attach_grad()
@@ -741,7 +774,7 @@ def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, num_layers,
     for _ in range(num_layers):
         stack_layer.add(stack_op(hidden_size))
     stack_layer.initialize()
-    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size)
+    check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, input_size, hidden_size)
 
 
 def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
@@ -755,7 +788,7 @@ def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, l
         stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size),
                                                     stack_op(hidden_size)))
     stack_layer.initialize()
-    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True)
+    check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, input_size, hidden_size, bidirectional=True)
 
 
 @mx.util.use_np
@@ -851,7 +884,7 @@ def test_layer_fill_shape():
     layer.hybridize()
     check_rnn_layer_forward(layer, mx.np.ones((3, 2, 7)))
     print(layer)
-    assert layer.l0_i2h_weight.shape[1] == 7, layer.l0_i2h_weight.shape[1]
+    assert layer.rnn_param.shape[0] == 760
 
 
 @pytest.mark.serial