From 393a23489fbb45eb0f7dfa4ff9a3a73a6c757a0a Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 24 Jun 2021 11:31:52 -0700 Subject: [PATCH 01/19] use rnn_params --- python/mxnet/gluon/rnn/rnn_layer.py | 122 ++++++---------------------- 1 file changed, 26 insertions(+), 96 deletions(-) diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 59056de6ce7b..167dc41e4389 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -34,9 +34,7 @@ class _RNNLayer(HybridBlock): """Implementation of recurrent layers.""" def __init__(self, hidden_size, num_layers, layout, dropout, bidirectional, input_size, - i2h_weight_initializer, h2h_weight_initializer, - i2h_bias_initializer, h2h_bias_initializer, - mode, projection_size, h2r_weight_initializer, + param_initializer, mode, projection_size, lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan, dtype, use_sequence_length=False, **kwargs): super(_RNNLayer, self).__init__(**kwargs) @@ -50,11 +48,6 @@ def __init__(self, hidden_size, num_layers, layout, self._dropout = dropout self._dir = 2 if bidirectional else 1 self._input_size = input_size - self._i2h_weight_initializer = i2h_weight_initializer - self._h2h_weight_initializer = h2h_weight_initializer - self._i2h_bias_initializer = i2h_bias_initializer - self._h2h_bias_initializer = h2h_bias_initializer - self._h2r_weight_initializer = h2r_weight_initializer self._lstm_state_clip_min = lstm_state_clip_min self._lstm_state_clip_max = lstm_state_clip_max self._lstm_state_clip_nan = lstm_state_clip_nan @@ -64,48 +57,8 @@ def __init__(self, hidden_size, num_layers, layout, self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] - ng, ni, nh = self._gates, input_size, hidden_size - if not projection_size: - for i in range(num_layers): - for j in ['l', 'r'][:self._dir]: - self._register_param('{}{}_i2h_weight'.format(j, i), - shape=(ng*nh, ni), - init=i2h_weight_initializer, dtype=dtype) - self._register_param('{}{}_h2h_weight'.format(j, i), - shape=(ng*nh, nh), - init=h2h_weight_initializer, dtype=dtype) - self._register_param('{}{}_i2h_bias'.format(j, i), - shape=(ng*nh,), - init=i2h_bias_initializer, dtype=dtype) - self._register_param('{}{}_h2h_bias'.format(j, i), - shape=(ng*nh,), - init=h2h_bias_initializer, dtype=dtype) - ni = nh * self._dir - else: - ps = self._projection_size - for i in range(num_layers): - for j in ['l', 'r'][:self._dir]: - self._register_param('{}{}_i2h_weight'.format(j, i), - shape=(ng*nh, ni), - init=i2h_weight_initializer, dtype=dtype) - self._register_param('{}{}_h2h_weight'.format(j, i), - shape=(ng*nh, ps), - init=h2h_weight_initializer, dtype=dtype) - self._register_param('{}{}_i2h_bias'.format(j, i), - shape=(ng*nh,), - init=i2h_bias_initializer, dtype=dtype) - self._register_param('{}{}_h2h_bias'.format(j, i), - shape=(ng*nh,), - init=h2h_bias_initializer, dtype=dtype) - self._register_param('{}{}_h2r_weight'.format(j, i), - shape=(ps, nh), - init=h2r_weight_initializer, dtype=dtype) - ni = ps * self._dir - - def _register_param(self, name, shape, init, dtype): - p = Parameter(name, shape=shape, init=init, allow_deferred_init=True, dtype=dtype) - setattr(self, name, p) - return p + self.rnn_param = Parameter('rnn_param', shape=(-1,), init=param_initializer, + allow_deferred_init=True, dtype=dtype) def __repr__(self): s = '{name}({mapping}, {_layout}' @@ -116,7 +69,7 @@ def __repr__(self): if self._dir == 2: s += ', bidirectional' s += ')' - shape = self.l0_i2h_weight.shape + 
shape = self.rnn_param.shape mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates) return s.format(name=self.__class__.__name__, mapping=mapping, @@ -196,37 +149,25 @@ def forward(self, inputs, states, sequence_length=None): def infer_shape(self, inputs, *args): assert inputs.ndim == 3, \ "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]" - if not self._projection_size: - step = self._hidden_size - else: - step = self._projection_size - ni = inputs.shape[2] - for i in range(self._num_layers): - for j in ['l', 'r'][:self._dir]: - name = '{}{}_i2h_weight'.format(j, i) - getattr(self, name).shape = (self._gates*self._hidden_size, ni) - ni = step * self._dir + ng, ni, nh = self._gates, inputs.shape[2], self._hidden_size + + size = nh * self._dir * ng + size1 = (ni + nh + 2) * size # first layer size + size2 = (nh * self._dir + nh + 2) * size # second layer size + if self._projection_size: + size1 = (ni + self._projection_size + 2) * size # first layer size + size2 = (self._projection_size * self._dir + \ + self._projection_size + 2) * size # second layer size + param_size = size1 + (self._num_layers - 1) * size2 + if self._projection_size: + param_size += self._projection_size * nh * self._num_layers * self._dir + self.rnn_param.shape = (param_size, ) def _forward_kernel(self, inputs, states, sequence_length): """ forward using CUDNN or CPU kenrel""" ctx = inputs.ctx if self._layout == 'NTC': inputs = np.swapaxes(inputs, 0, 1) - if self._projection_size is None: - params = (getattr(self, '{}{}_{}_{}'.format(d, l, g, t)).data(ctx).reshape(-1) - for t in ['weight', 'bias'] - for l in range(self._num_layers) - for d in ['l', 'r'][:self._dir] - for g in ['i2h', 'h2h']) - else: - params = (getattr(self, '{}{}_{}_{}'.format(d, l, g, t)).data(ctx).reshape(-1) - for t in ['weight', 'bias'] - for l in range(self._num_layers) - for d in ['l', 'r'][:self._dir] - for g in ['i2h', 'h2h', 'h2r'] - if g != 'h2r' or t != 'bias') - - params = np.concatenate(params, axis=0) if self._use_sequence_length: rnn_args = states + [sequence_length] @@ -238,7 +179,8 @@ def _forward_kernel(self, inputs, states, sequence_length): new_args = args.as_in_ctx(ctx) rnn_args_ctx.append(new_args) - rnn = npx.rnn(inputs, params, *rnn_args_ctx, use_sequence_length=self._use_sequence_length, + rnn = npx.rnn(inputs, self.rnn_param.data().as_in_ctx(ctx), *rnn_args_ctx, + use_sequence_length=self._use_sequence_length, state_size=self._hidden_size, projection_size=self._projection_size, num_layers=self._num_layers, bidirectional=self._dir == 2, p=self._dropout, state_outputs=True, mode=self._mode, @@ -334,15 +276,11 @@ class RNN(_RNNLayer): >>> output, hn = layer(input, h0) """ def __init__(self, hidden_size, num_layers=1, activation='relu', - layout='TNC', dropout=0, bidirectional=False, - i2h_weight_initializer=None, h2h_weight_initializer=None, - i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', + layout='TNC', dropout=0, bidirectional=False, param_initializer=None, input_size=0, dtype='float32', **kwargs): super(RNN, self).__init__(hidden_size, num_layers, layout, - dropout, bidirectional, input_size, - i2h_weight_initializer, h2h_weight_initializer, - i2h_bias_initializer, h2h_bias_initializer, - 'rnn_'+activation, None, None, None, None, False, + dropout, bidirectional, input_size, param_initializer, + 'rnn_'+activation, None, None, None, False, dtype, **kwargs) def state_info(self, batch_size=0): @@ -451,16 +389,12 @@ class LSTM(_RNNLayer): """ def 
__init__(self, hidden_size, num_layers=1, layout='TNC', dropout=0, bidirectional=False, input_size=0, - i2h_weight_initializer=None, h2h_weight_initializer=None, - i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - projection_size=None, h2r_weight_initializer=None, + param_initializer=None, projection_size=None, state_clip_min=None, state_clip_max=None, state_clip_nan=False, dtype='float32', **kwargs): super(LSTM, self).__init__(hidden_size, num_layers, layout, dropout, bidirectional, input_size, - i2h_weight_initializer, h2h_weight_initializer, - i2h_bias_initializer, h2h_bias_initializer, - 'lstm', projection_size, h2r_weight_initializer, + param_initializer, 'lstm', projection_size, state_clip_min, state_clip_max, state_clip_nan, dtype, **kwargs) @@ -560,14 +494,10 @@ class GRU(_RNNLayer): """ def __init__(self, hidden_size, num_layers=1, layout='TNC', dropout=0, bidirectional=False, input_size=0, - i2h_weight_initializer=None, h2h_weight_initializer=None, - i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - dtype='float32', **kwargs): + param_initializer=None, dtype='float32', **kwargs): super(GRU, self).__init__(hidden_size, num_layers, layout, dropout, bidirectional, input_size, - i2h_weight_initializer, h2h_weight_initializer, - i2h_bias_initializer, h2h_bias_initializer, - 'gru', None, None, None, None, False, + param_initializer, 'gru', None, None, None, False, dtype, **kwargs) def state_info(self, batch_size=0): From b8bff7787c5802f1cfb56bc8e484ad54f8fd7a35 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 24 Jun 2021 20:07:15 -0700 Subject: [PATCH 02/19] add split rnn parameter in gluon.utils --- python/mxnet/gluon/rnn/rnn_layer.py | 2 +- python/mxnet/gluon/utils.py | 50 +++++++++++++++++++++++++ tests/python/unittest/test_gluon.py | 2 +- tests/python/unittest/test_gluon_rnn.py | 34 +++++++++-------- 4 files changed, 70 insertions(+), 18 deletions(-) diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 167dc41e4389..69e67faf484a 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -58,7 +58,7 @@ def __init__(self, hidden_size, num_layers, layout, self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] self.rnn_param = Parameter('rnn_param', shape=(-1,), init=param_initializer, - allow_deferred_init=True, dtype=dtype) + allow_deferred_init=True, dtype=dtype) def __repr__(self): s = '{name}({mapping}, {_layout}' diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index df641cf1ace5..5e8f6790fe41 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -504,3 +504,53 @@ def _check_block_input_np_ndarrays(inputs): for i in inputs: _check_block_input_np_ndarrays(i) # pylint: enable=no-else-raise + + +def split_rnn_params(param, mode, num_layers, input_size, hidden_size, bidirectional, projection_size=None): + """Split rnn layer parameter into weight and bias in different layer.""" + gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] + dir = 2 if bidirectional else 1 + param_dict = {} + begin = 0 + if not projection_size: + for p in ['weight', 'bias']: + for l in range(num_layers): + for d in ['l', 'r'][:dir]: + for g in ['i2h', 'h2h']: + ni = input_size + if l != 0: + ni = hidden_size * dir + if g == 'h2h': + ni = hidden_size + shape0 = gates * hidden_size + if p == 'weight': + cur_len = shape0 * ni + param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \ + param[begin:begin+cur_len].reshape(shape0, ni) + else: + 
cur_len = shape0 + param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \ + param[begin:begin+cur_len].reshape(shape0,) + begin += cur_len + else: + for p in ['weight', 'bias']: + for l in range(num_layers): + for d in ['l', 'r'][:dir]: + for g in ['i2h', 'h2h', 'h2r']: + if g != 'h2r' or p != 'bias': + ni = input_size + if l != 0: + ni = projection_size * dir + if g == 'h2h': + ni = projection_size + shape0 = gates * hidden_size + if p == 'weight': + cur_len = shape0 * ni + param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \ + param[begin:begin+cur_len].reshape(shape0, ni) + else: + cur_len = shape0 + param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \ + param[begin:begin+cur_len].reshape(shape0,) + begin += cur_len + return param_dict diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 6d5e40c31ecf..bc6b63d54737 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -1347,7 +1347,7 @@ def forward(self, x): x = self.encoders[i](x) return x net = Network() - net.initialize(mx.init.Xavier(), ctx=mx.cpu()) + net.initialize(mx.init.Uniform(), ctx=mx.cpu()) net.hybridize() x = onp.random.rand(32, 10, 10) x = mx.np.array(x).as_in_context(mx.cpu()) diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 5a2661dddb54..172e055dc769 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -21,6 +21,7 @@ import copy from functools import partial from numpy.testing import assert_allclose +from mxnet.gluon.utils import split_rnn_params import pytest from mxnet.test_utils import almost_equal, assert_almost_equal, default_context from common import assert_raises_cudnn_not_satisfied, retry @@ -154,10 +155,10 @@ def test_lstmp(): @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_lstm_cpu_inference(): # should behave the same as lstm cell - EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213], - [0.72045636, 0.72045636, 0.95215213, 0.95215213]], - [[0.95215213, 0.95215213, 0.72045636, 0.72045636], - [0.95215213, 0.95215213, 0.72045636, 0.72045636]]]) + EXPECTED_LSTM_OUTPUT = np.array([[[0.7564661, 0.7564661, 0.96265966, 0.96265966], + [0.7564661, 0.7564661, 0.96265966, 0.96265966]], + [[0.96265966, 0.96265966, 0.7564661, 0.7564661], + [0.96265966, 0.96265966, 0.7564661, 0.7564661]]]) x = mx.np.ones(shape=(2, 2, 2)) model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True) model.initialize(mx.init.One()) @@ -656,7 +657,7 @@ def test_rnn_layers_fp16(): run_rnn_layers('float16', 'float32', mx.gpu()) -def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=False, rtol=1e-2, atol=1e-4): +def check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, input_size, hidden_size, bidirectional=False, rtol=1e-2, atol=1e-4): x = mx.np.random.normal(size=(1, 5, input_size)) fused_begin_state = fused_layer.begin_state(1) stack_states = stack_layer.begin_state(batch_size=1) @@ -666,16 +667,15 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_siz stack_layer.initialize() stack_layer_params = stack_layer.collect_params() - for name, value in fused_layer_params.items(): - if 'weight' in name: - w = mx.np.zeros(shape=value.shape) - else: - w = mx.np.random.normal(size=value.shape) - value.set_data(w.copy()) + fused_weight_shape = fused_layer_params['rnn_param'].shape + w = mx.np.random.normal(size=fused_weight_shape) + 
fused_layer_params['rnn_param'].set_data(w.copy()) + fused_layer_params_split = split_rnn_params(w.copy(), mode, num_layers, input_size, hidden_size, bidirectional) + for name, value in fused_layer_params_split.items(): cur = name.split('_')[0] num = cur[1:] stack_name = ('{}.{}_cell.'.format(num, name[0]) if bidirectional else num + '.' ) + name[len(cur)+1:] - stack_layer_params[stack_name].set_data(w.copy()) + stack_layer_params[stack_name].set_data(value) fx = x.copy() sx = x.copy() @@ -686,7 +686,9 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_siz l = loss(fused_out, y).mean() l.backward() mx.npx.waitall() - fused_grads = dict([(name, p.grad()) for name, p in fused_layer.collect_params().items()]) + fused_layer_param_split = split_rnn_params(fused_layer.collect_params()['rnn_param'].data().grad,\ + mode, num_layers, input_size, hidden_size, bidirectional) + fused_grads = dict([(name, p) for name, p in fused_layer_param_split.items()]) fused_input_grad = fx.grad.asnumpy() sx.attach_grad() @@ -741,7 +743,7 @@ def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, num_layers, for n in range(num_layers): stack_layer.add(stack_op(hidden_size)) stack_layer.initialize() - check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size) + check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, input_size, hidden_size) def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss): @@ -755,7 +757,7 @@ def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, l stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size), stack_op(hidden_size))) stack_layer.initialize() - check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True) + check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, input_size, hidden_size, bidirectional=True) @mx.util.use_np @@ -851,7 +853,7 @@ def test_layer_fill_shape(): layer.hybridize() check_rnn_layer_forward(layer, mx.np.ones((3, 2, 7))) print(layer) - assert layer.l0_i2h_weight.shape[1] == 7, layer.l0_i2h_weight.shape[1] + assert layer.rnn_param.shape[0] == 760 @pytest.mark.serial From 1512a6fdb391cabba1a1c6f17bb1f91ec8bef2e8 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 24 Jun 2021 20:17:29 -0700 Subject: [PATCH 03/19] update --- python/mxnet/gluon/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 5e8f6790fe41..30f5515447ff 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -506,6 +506,7 @@ def _check_block_input_np_ndarrays(inputs): # pylint: enable=no-else-raise +# pylint: disable=too-many-nested-blocks def split_rnn_params(param, mode, num_layers, input_size, hidden_size, bidirectional, projection_size=None): """Split rnn layer parameter into weight and bias in different layer.""" gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] From 26fa6331fef3e02d29c157a41ecd8a7552e04c39 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 24 Jun 2021 20:23:24 -0700 Subject: [PATCH 04/19] update --- python/mxnet/gluon/rnn/rnn_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 69e67faf484a..7ab5084f5805 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -69,8 +69,7 @@ def __repr__(self): if self._dir == 2: s += ', bidirectional' s 
+= ')' - shape = self.rnn_param.shape - mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates) + mapping = '{0} -> {1}'.format(self._input_size if self._input_size else None, self._hidden_size) return s.format(name=self.__class__.__name__, mapping=mapping, **self.__dict__) @@ -149,6 +148,7 @@ def forward(self, inputs, states, sequence_length=None): def infer_shape(self, inputs, *args): assert inputs.ndim == 3, \ "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]" + self._input_size = inputs.shape[2] ng, ni, nh = self._gates, inputs.shape[2], self._hidden_size size = nh * self._dir * ng From fef6380787fba7e66b53289ab021e5a047ffdc33 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Fri, 25 Jun 2021 16:06:47 -0700 Subject: [PATCH 05/19] use zero weight --- tests/python/unittest/test_gluon_rnn.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 172e055dc769..9d858305dbb4 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -347,15 +347,28 @@ def forward(self, inpt): weights['{}0_i2h_bias'.format(d)] = mx.np.random.uniform(size=(size*4,)) weights['{}0_h2h_bias'.format(d)] = mx.np.random.uniform(size=(size*4,)) + params = (weights['{}0_{}_{}'.format(d, g, t)].reshape(-1) + for t in ['weight', 'bias'] + for d in ['l', 'r'] + for g in ['i2h', 'h2h']) + net_params_concat = mx.np.concatenate(params) + params_left = (weights['l0_{}_{}'.format(g, t)].reshape(-1) + for t in ['weight', 'bias'] + for g in ['i2h', 'h2h']) + params_right = (weights['r0_{}_{}'.format(g, t)].reshape(-1) + for t in ['weight', 'bias'] + for g in ['i2h', 'h2h']) + net_ref_left_params = mx.np.concatenate(params_left) + net_ref_right_params = mx.np.concatenate(params_right) net = gluon.rnn.LSTM(size, bidirectional=True) ref_net = RefBiLSTM(size) net.initialize() ref_net.initialize() net_params = net.collect_params() ref_net_params = ref_net.collect_params() - for k in weights: - net_params[k].set_data(weights[k]) - ref_net_params[k.replace('l0', '_lstm_fwd.l0').replace('r0', '_lstm_bwd.l0')].set_data(weights[k]) + net_params['rnn_param'].set_data(net_params_concat) + ref_net_params['_lstm_fwd.rnn_param'].set_data(net_ref_left_params) + ref_net_params['_lstm_bwd.rnn_param'].set_data(net_ref_right_params) data = mx.np.random.uniform(size=(11, 10, in_size)) assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy(), rtol=1e-04, atol=1e-02) @@ -668,7 +681,7 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, inpu stack_layer_params = stack_layer.collect_params() fused_weight_shape = fused_layer_params['rnn_param'].shape - w = mx.np.random.normal(size=fused_weight_shape) + w = mx.np.zeros(shape=fused_weight_shape) fused_layer_params['rnn_param'].set_data(w.copy()) fused_layer_params_split = split_rnn_params(w.copy(), mode, num_layers, input_size, hidden_size, bidirectional) for name, value in fused_layer_params_split.items(): @@ -686,9 +699,8 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, inpu l = loss(fused_out, y).mean() l.backward() mx.npx.waitall() - fused_layer_param_split = split_rnn_params(fused_layer.collect_params()['rnn_param'].data().grad,\ + fused_grads = split_rnn_params(fused_layer.collect_params()['rnn_param'].data().grad,\ mode, num_layers, input_size, hidden_size, bidirectional) - fused_grads = dict([(name, p) for name, 
p in fused_layer_param_split.items()]) fused_input_grad = fx.grad.asnumpy() sx.attach_grad() From 1cd0ce0879422038aa9d3b77ad95b55f84c59e51 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 29 Jun 2021 09:57:43 -0700 Subject: [PATCH 06/19] add rnn fused parameter initializer --- python/mxnet/gluon/rnn/rnn_layer.py | 6 +- python/mxnet/initializer.py | 93 +++++++++++++++++++++++++ tests/python/unittest/test_gluon_rnn.py | 4 +- 3 files changed, 100 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 7ab5084f5805..8161bc9595c2 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -23,7 +23,7 @@ __all__ = ['RNN', 'LSTM', 'GRU'] -from ... import np, npx, context +from ... import np, npx, context, initializer from .. import HybridBlock, tensor_types from ..parameter import Parameter from ...util import use_np @@ -57,6 +57,10 @@ def __init__(self, hidden_size, num_layers, layout, self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] + if not param_initializer: + param_initializer = initializer.RNNFused(mode, num_layers, hidden_size,\ + bidirectional, projection_size) + self.rnn_param = Parameter('rnn_param', shape=(-1,), init=param_initializer, allow_deferred_init=True, dtype=dtype) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index 60a1d3d667cc..1efd0e254d49 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -711,3 +711,96 @@ def _init_weight(self, name, arr): # gate of the 4 LSTM gates, we modify the according values. num_hidden = int(arr.shape[0] / 4) arr[num_hidden:2*num_hidden] = self.forget_bias + + +@register +class RNNFused(Initializer): + """Initialize RNN fused parameter with bias part initialized to 0.0 and + weight initialized with random values uniformly sampled from a given range. + + Parameters + ---------- + mode : {'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required + the type of RNN to compute + num_layers : int (non-negative), required + number of stacked layers + state_size : int (non-negative), required + size of the state for each layer + bidirectional : boolean, optional, default=0 + whether to use bidirectional recurrent layers + projection_size : int or None, optional, default='None' + size of project size + scale : float, optional + The bound on the range of the generated random values for weights. + Values are generated from the range [-`scale`, `scale`]. + Default scale is 0.07. 
+ """ + def __init__(self, mode, num_layers, state_size, bidirectional=False, + projection_size=None, scale=0.07): + super(RNNFused, self).__init__(mode=mode, num_layers=num_layers, + state_size=state_size, + bidirectional=bidirectional, + projection_size=projection_size, + scale=scale) + self.gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] + self.num_layers = num_layers + self.num_hidden = state_size + self.dir = 2 if bidirectional else 1 + self.projection_size = projection_size + self.scale = scale + + def _init_weight(self, name, arr): + arr_len = arr.shape[0] + dtype = arr.dtype + size = self.num_hidden * self.dir * self.gates + if not self.projection_size: + # second layer size + size2 = (self.num_hidden * self.dir + self.num_hidden + 2) * size + input_size = (arr_len - (self.num_layers - 1) * size2) // \ + size - 2 - self.num_hidden + else: + # second layer size + size2 = (self.projection_size * self.dir + self.projection_size + 2) * size + size_projection = self.projection_size * self.num_hidden * self.num_layers * self.dir + input_size = (arr_len - size_projection - (self.num_layers - 1) * size2) // \ + size - 2 - self.projection_size + begin = 0 + if not self.projection_size: + for p in ['weight', 'bias']: + for l in range(self.num_layers): + for _ in range(self.dir): + for g in ['i2h', 'h2h']: + ni = input_size + if l != 0: + ni = self.num_hidden * self.dir + if g == 'h2h': + ni = self.num_hidden + shape0 = self.gates * self.num_hidden + if p == 'weight': + cur_len = shape0 * ni + _mx_np.random.uniform(-self.scale, self.scale, \ + size=(cur_len,), dtype=dtype, out=arr[begin:begin+cur_len]) + else: + cur_len = shape0 + arr[begin:begin+cur_len] = 0.0 + begin += cur_len + else: + for p in ['weight', 'bias']: + for l in range(self.num_layers): + for _ in range(self.dir): + for g in ['i2h', 'h2h', 'h2r']: + if g != 'h2r' or p != 'bias': + ni = input_size + if l != 0: + ni = self.projection_size * dir + if g == 'h2h': + ni = self.projection_size + shape0 = self.gates * self.num_hidden + if p == 'weight': + cur_len = shape0 * ni + _mx_np.random.uniform(-self.scale, self.scale, \ + size=(cur_len,), dtype=dtype, out=arr[begin:begin+cur_len]) + else: + cur_len = shape0 + arr[begin:begin+cur_len] = 0.0 + begin += cur_len diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 9d858305dbb4..a3543f62e202 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -160,8 +160,8 @@ def test_lstm_cpu_inference(): [[0.96265966, 0.96265966, 0.7564661, 0.7564661], [0.96265966, 0.96265966, 0.7564661, 0.7564661]]]) x = mx.np.ones(shape=(2, 2, 2)) - model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True) - model.initialize(mx.init.One()) + model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True, param_initializer=mx.init.One()) + model.initialize() y = model(x).asnumpy() mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT, From 296628ec46db93e84bd5590ca28104d96b263485 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 29 Jun 2021 10:18:00 -0700 Subject: [PATCH 07/19] fix lint --- python/mxnet/initializer.py | 45 +++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index 1efd0e254d49..742d901d0816 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -733,7 +733,7 @@ class RNNFused(Initializer): scale : float, optional The bound on the range of the 
generated random values for weights. Values are generated from the range [-`scale`, `scale`]. - Default scale is 0.07. + Default scale is 0.07. """ def __init__(self, mode, num_layers, state_size, bidirectional=False, projection_size=None, scale=0.07): @@ -749,6 +749,7 @@ def __init__(self, mode, num_layers, state_size, bidirectional=False, self.projection_size = projection_size self.scale = scale + # pylint: disable=too-many-nested-blocks def _init_weight(self, name, arr): arr_len = arr.shape[0] dtype = arr.dtype @@ -766,18 +767,18 @@ def _init_weight(self, name, arr): size - 2 - self.projection_size begin = 0 if not self.projection_size: - for p in ['weight', 'bias']: - for l in range(self.num_layers): + for param in ['weight', 'bias']: + for layer_num in range(self.num_layers): for _ in range(self.dir): - for g in ['i2h', 'h2h']: - ni = input_size - if l != 0: - ni = self.num_hidden * self.dir - if g == 'h2h': - ni = self.num_hidden + for connect in ['i2h', 'h2h']: + num_inputs = input_size + if layer_num != 0: + num_inputs = self.num_hidden * self.dir + if connect == 'h2h': + num_inputs = self.num_hidden shape0 = self.gates * self.num_hidden - if p == 'weight': - cur_len = shape0 * ni + if param == 'weight': + cur_len = shape0 * num_inputs _mx_np.random.uniform(-self.scale, self.scale, \ size=(cur_len,), dtype=dtype, out=arr[begin:begin+cur_len]) else: @@ -785,19 +786,19 @@ def _init_weight(self, name, arr): arr[begin:begin+cur_len] = 0.0 begin += cur_len else: - for p in ['weight', 'bias']: - for l in range(self.num_layers): + for param in ['weight', 'bias']: + for layer_num in range(self.num_layers): for _ in range(self.dir): - for g in ['i2h', 'h2h', 'h2r']: - if g != 'h2r' or p != 'bias': - ni = input_size - if l != 0: - ni = self.projection_size * dir - if g == 'h2h': - ni = self.projection_size + for connect in ['i2h', 'h2h', 'h2r']: + if connect != 'h2r' or param != 'bias': + num_inputs = input_size + if layer_num != 0: + num_inputs = self.projection_size * dir + if connect == 'h2h': + num_inputs = self.projection_size shape0 = self.gates * self.num_hidden - if p == 'weight': - cur_len = shape0 * ni + if param == 'weight': + cur_len = shape0 * num_inputs _mx_np.random.uniform(-self.scale, self.scale, \ size=(cur_len,), dtype=dtype, out=arr[begin:begin+cur_len]) else: From 3947da9f26bd2dcb1553f82e22f7f13f7fd618c8 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 29 Jun 2021 19:44:26 -0700 Subject: [PATCH 08/19] fix tests --- python/mxnet/gluon/utils.py | 30 +++++++++++++++---------- python/mxnet/initializer.py | 25 ++++++++++++--------- tests/python/unittest/test_gluon_rnn.py | 27 ++++++++++++++-------- 3 files changed, 51 insertions(+), 31 deletions(-) diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 30f5515447ff..5ba4422ce6e4 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -539,19 +539,25 @@ def split_rnn_params(param, mode, num_layers, input_size, hidden_size, bidirecti for d in ['l', 'r'][:dir]: for g in ['i2h', 'h2h', 'h2r']: if g != 'h2r' or p != 'bias': - ni = input_size - if l != 0: - ni = projection_size * dir - if g == 'h2h': - ni = projection_size - shape0 = gates * hidden_size - if p == 'weight': - cur_len = shape0 * ni + if g == 'h2r': + cur_len = projection_size * hidden_size param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \ - param[begin:begin+cur_len].reshape(shape0, ni) + param[begin:begin+cur_len]. 
\ + reshape(projection_size, hidden_size) else: - cur_len = shape0 - param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \ - param[begin:begin+cur_len].reshape(shape0,) + ni = input_size + if l != 0: + ni = projection_size * dir + if g == 'h2h': + ni = projection_size + shape0 = gates * hidden_size + if p == 'weight': + cur_len = shape0 * ni + param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \ + param[begin:begin+cur_len].reshape(shape0, ni) + else: + cur_len = shape0 + param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \ + param[begin:begin+cur_len].reshape(shape0,) begin += cur_len return param_dict diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index 742d901d0816..f313b587f710 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -791,17 +791,22 @@ def _init_weight(self, name, arr): for _ in range(self.dir): for connect in ['i2h', 'h2h', 'h2r']: if connect != 'h2r' or param != 'bias': - num_inputs = input_size - if layer_num != 0: - num_inputs = self.projection_size * dir - if connect == 'h2h': - num_inputs = self.projection_size - shape0 = self.gates * self.num_hidden - if param == 'weight': - cur_len = shape0 * num_inputs + if connect == 'h2r': + cur_len = self.projection_size * self.num_hidden _mx_np.random.uniform(-self.scale, self.scale, \ size=(cur_len,), dtype=dtype, out=arr[begin:begin+cur_len]) else: - cur_len = shape0 - arr[begin:begin+cur_len] = 0.0 + num_inputs = input_size + if layer_num != 0: + num_inputs = self.projection_size * dir + if connect == 'h2h': + num_inputs = self.projection_size + shape0 = self.gates * self.num_hidden + if param == 'weight': + cur_len = shape0 * num_inputs + _mx_np.random.uniform(-self.scale, self.scale, \ + size=(cur_len,), dtype=dtype, out=arr[begin:begin+cur_len]) + else: + cur_len = shape0 + arr[begin:begin+cur_len] = 0.0 begin += cur_len diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index a3543f62e202..73c2359a560f 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -100,10 +100,15 @@ def test_lstmp(): fused_layer_params = fused_layer.collect_params() stack_layer_params = stack_layer.collect_params() - for name, value in fused_layer_params.items(): - w = mx.np.random.uniform(size=value.shape) - value.set_data(w.copy()) - stack_layer_params[name[1:].replace('_', '.', 1)].set_data(w.copy()) + fused_weight_shape = fused_layer_params['rnn_param'].shape + print(fused_weight_shape) + w = mx.np.zeros(shape=fused_weight_shape) + fused_layer_params['rnn_param'].set_data(w.copy()) + fused_layer_params_split = split_rnn_params(w.copy(), 'lstm', num_layers, input_size,\ + hidden_size, False, projection_size=projection_size) + + for name, value in fused_layer_params_split.items(): + stack_layer_params[name[1:].replace('_', '.', 1)].set_data(value) fused_output, fused_states = fused_layer(lstm_input.copy(), fused_begin_state) stack_output, stack_states = stack_layer.unroll(seq_len, lstm_input.copy(), begin_state=stack_begin_state, @@ -136,11 +141,15 @@ def test_lstmp(): fused_layer_params = fused_layer.collect_params() stack_layer_params = stack_layer.collect_params() - for name, value in fused_layer_params.items(): - w = mx.np.random.uniform(size=value.shape) - value.set_data(w.copy()) + fused_weight_shape = fused_layer_params['rnn_param'].shape + w = mx.np.zeros(shape=fused_weight_shape) + fused_layer_params['rnn_param'].set_data(w.copy()) + fused_layer_params_split = split_rnn_params(w.copy(), 'lstm', num_layers, 
input_size,\ + hidden_size, True, projection_size=projection_size) + + for name, value in fused_layer_params_split.items(): cur = name.split("_")[0] - stack_layer_params["{}.{}_cell.{}".format(cur[1:], name[0], name[len(cur)+1:])].set_data(w.copy()) + stack_layer_params["{}.{}_cell.{}".format(cur[1:], name[0], name[len(cur)+1:])].set_data(value) fused_output, fused_states = fused_layer(lstm_input.copy(), fused_begin_state) stack_output, stack_states = stack_layer.unroll(seq_len, lstm_input.copy(), begin_state=stack_begin_state, @@ -985,7 +994,7 @@ def test_conv_fill_shape(): @mx.util.use_np -def test_lstmp(): +def test_lstmp_cell(): nhid = 100 nproj = 64 cell = gluon.rnn.LSTMPCell(nhid, nproj) From 68e18fac99487d9d690148944f77f37091e1221a Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 1 Jul 2021 11:05:34 -0700 Subject: [PATCH 09/19] update RNNFused initializer --- python/mxnet/gluon/parameter.py | 3 + python/mxnet/gluon/rnn/rnn_layer.py | 41 ++++++++++---- python/mxnet/initializer.py | 44 +++++++++++---- tests/python/gpu/test_gluon_gpu.py | 75 ++++++++++++++++++------- tests/python/unittest/test_gluon_rnn.py | 12 ++-- 5 files changed, 126 insertions(+), 49 deletions(-) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 16e2957c2551..f354f12db2d5 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -472,6 +472,9 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), ctx = [context.current_context()] if isinstance(ctx, Context): ctx = [ctx] + if isinstance(self.init, initializer.RNNFused): + self.init.set_initializer(init if init else default_init) + init = default_init = self.init if init is None: init = default_init if self.init is None else self.init if not shape_is_known(self.shape): diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 8161bc9595c2..8a854ca3f0c3 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -34,7 +34,9 @@ class _RNNLayer(HybridBlock): """Implementation of recurrent layers.""" def __init__(self, hidden_size, num_layers, layout, dropout, bidirectional, input_size, - param_initializer, mode, projection_size, + i2h_weight_initializer, h2h_weight_initializer, + i2h_bias_initializer, h2h_bias_initializer, + mode, projection_size, h2r_weight_initializer, lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan, dtype, use_sequence_length=False, **kwargs): super(_RNNLayer, self).__init__(**kwargs) @@ -57,9 +59,14 @@ def __init__(self, hidden_size, num_layers, layout, self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] - if not param_initializer: - param_initializer = initializer.RNNFused(mode, num_layers, hidden_size,\ - bidirectional, projection_size) + param_initializer = initializer.RNNFused( + mode, num_layers, hidden_size, + bidirectional, projection_size, + i2h_weight_initializer=i2h_weight_initializer, + h2h_weight_initializer=h2h_weight_initializer, + i2h_bias_initializer=i2h_bias_initializer, + h2h_bias_initializer=h2h_bias_initializer, + h2r_weight_initializer=h2r_weight_initializer) self.rnn_param = Parameter('rnn_param', shape=(-1,), init=param_initializer, allow_deferred_init=True, dtype=dtype) @@ -280,11 +287,15 @@ class RNN(_RNNLayer): >>> output, hn = layer(input, h0) """ def __init__(self, hidden_size, num_layers=1, activation='relu', - layout='TNC', dropout=0, bidirectional=False, param_initializer=None, + layout='TNC', dropout=0, bidirectional=False, + 
i2h_weight_initializer=None, h2h_weight_initializer=None, + i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', input_size=0, dtype='float32', **kwargs): super(RNN, self).__init__(hidden_size, num_layers, layout, - dropout, bidirectional, input_size, param_initializer, - 'rnn_'+activation, None, None, None, False, + dropout, bidirectional, input_size, + i2h_weight_initializer, h2h_weight_initializer, + i2h_bias_initializer, h2h_bias_initializer, + 'rnn_'+activation, None, None, None, None, False, dtype, **kwargs) def state_info(self, batch_size=0): @@ -393,12 +404,16 @@ class LSTM(_RNNLayer): """ def __init__(self, hidden_size, num_layers=1, layout='TNC', dropout=0, bidirectional=False, input_size=0, - param_initializer=None, projection_size=None, + i2h_weight_initializer=None, h2h_weight_initializer=None, + i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', + projection_size=None, h2r_weight_initializer=None, state_clip_min=None, state_clip_max=None, state_clip_nan=False, dtype='float32', **kwargs): super(LSTM, self).__init__(hidden_size, num_layers, layout, dropout, bidirectional, input_size, - param_initializer, 'lstm', projection_size, + i2h_weight_initializer, h2h_weight_initializer, + i2h_bias_initializer, h2h_bias_initializer, + 'lstm', projection_size, h2r_weight_initializer, state_clip_min, state_clip_max, state_clip_nan, dtype, **kwargs) @@ -498,10 +513,14 @@ class GRU(_RNNLayer): """ def __init__(self, hidden_size, num_layers=1, layout='TNC', dropout=0, bidirectional=False, input_size=0, - param_initializer=None, dtype='float32', **kwargs): + i2h_weight_initializer=None, h2h_weight_initializer=None, + i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', + dtype='float32', **kwargs): super(GRU, self).__init__(hidden_size, num_layers, layout, dropout, bidirectional, input_size, - param_initializer, 'gru', None, None, None, False, + i2h_weight_initializer, h2h_weight_initializer, + i2h_bias_initializer, h2h_bias_initializer, + 'gru', None, None, None, None, False, dtype, **kwargs) def state_info(self, batch_size=0): diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index f313b587f710..b317e9e18f9f 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -736,23 +736,34 @@ class RNNFused(Initializer): Default scale is 0.07. 
""" def __init__(self, mode, num_layers, state_size, bidirectional=False, - projection_size=None, scale=0.07): + projection_size=None, scale=0.07, i2h_weight_initializer=None, + h2h_weight_initializer=None, i2h_bias_initializer=None, + h2h_bias_initializer=None, h2r_weight_initializer=None): super(RNNFused, self).__init__(mode=mode, num_layers=num_layers, state_size=state_size, bidirectional=bidirectional, projection_size=projection_size, - scale=scale) + scale=scale, + i2h_weight_initializer=i2h_weight_initializer, + h2h_weight_initializer=h2h_weight_initializer, + i2h_bias_initializer=i2h_bias_initializer, + h2h_bias_initializer=h2h_bias_initializer, + h2r_weight_initializer=h2r_weight_initializer) self.gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] self.num_layers = num_layers self.num_hidden = state_size self.dir = 2 if bidirectional else 1 self.projection_size = projection_size self.scale = scale + self._i2h_weight_initializer = i2h_weight_initializer + self._h2h_weight_initializer = h2h_weight_initializer + self._i2h_bias_initializer = i2h_bias_initializer + self._h2h_bias_initializer = h2h_bias_initializer + self._h2r_weight_initializer = h2r_weight_initializer # pylint: disable=too-many-nested-blocks def _init_weight(self, name, arr): arr_len = arr.shape[0] - dtype = arr.dtype size = self.num_hidden * self.dir * self.gates if not self.projection_size: # second layer size @@ -779,11 +790,9 @@ def _init_weight(self, name, arr): shape0 = self.gates * self.num_hidden if param == 'weight': cur_len = shape0 * num_inputs - _mx_np.random.uniform(-self.scale, self.scale, \ - size=(cur_len,), dtype=dtype, out=arr[begin:begin+cur_len]) else: cur_len = shape0 - arr[begin:begin+cur_len] = 0.0 + self._init_util(param, connect, arr[begin:begin+cur_len]) begin += cur_len else: for param in ['weight', 'bias']: @@ -793,8 +802,6 @@ def _init_weight(self, name, arr): if connect != 'h2r' or param != 'bias': if connect == 'h2r': cur_len = self.projection_size * self.num_hidden - _mx_np.random.uniform(-self.scale, self.scale, \ - size=(cur_len,), dtype=dtype, out=arr[begin:begin+cur_len]) else: num_inputs = input_size if layer_num != 0: @@ -804,9 +811,24 @@ def _init_weight(self, name, arr): shape0 = self.gates * self.num_hidden if param == 'weight': cur_len = shape0 * num_inputs - _mx_np.random.uniform(-self.scale, self.scale, \ - size=(cur_len,), dtype=dtype, out=arr[begin:begin+cur_len]) else: cur_len = shape0 - arr[begin:begin+cur_len] = 0.0 + self._init_util(param, connect, arr[begin:begin+cur_len]) begin += cur_len + + def _init_util(self, param, connect, arr): + name = "_{}_{}_initializer".format(connect, param) + init = getattr(self, name) + create(init)(InitDesc(name, {'__init__': init}), arr) + + def set_initializer(self, init): + self._i2h_weight_initializer = \ + init if not self._i2h_weight_initializer else 'uniform' + self._h2h_weight_initializer = \ + init if not self._h2h_weight_initializer else 'uniform' + self._i2h_bias_initializer = \ + init if not self._i2h_bias_initializer else 'zero' + self._h2h_bias_initializer = \ + init if not self._i2h_bias_initializer else 'zero' + self._h2r_weight_initializer = \ + init if not self._h2r_weight_initializer else 'uniform' diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 9b42e7452516..6abdb8461de6 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -74,7 +74,7 @@ def check_rnn_layer_w_rand_inputs(layer): @mx.util.use_np 
@assert_raises_cudnn_not_satisfied(min_version='7.2.1') -def test_lstmp(): +def test_lstmp_gpu(): hidden_size, projection_size = 3, 2 rtol, atol = 1e-2, 1e-2 batch_size, seq_len = 7, 11 @@ -97,8 +97,14 @@ def test_lstmp(): lstm_cell.initialize(ctx=ctx) layer_params = lstm_layer.collect_params() cell_params = lstm_cell.collect_params() + params = (weights['{}_{}'.format(g, t)].reshape(-1) + for t in ['weight', 'bias'] + for g in ['i2h', 'h2h', 'h2r'] + if g != 'h2r' or t != 'bias') + + net_params_concat = mx.np.concatenate(params) + layer_params['rnn_param'].set_data(net_params_concat.copy()) for k, v in weights.items(): - layer_params['l0_' + k].set_data(v.copy()) cell_params[k].set_data(v.copy()) with autograd.record(): layer_output = lstm_layer(lstm_input.copy()) @@ -108,8 +114,10 @@ def test_lstmp(): assert_almost_equal(layer_output, cell_output, rtol=rtol, atol=atol) layer_output.backward() cell_output.backward() + layer_params_split = split_rnn_params(layer_params['rnn_param'].data().copy(),\ + 'lstm', 1, input_size, hidden_size, False, projection_size=projection_size) for k, v in weights.items(): - layer_grad = layer_params['l0_' + k].grad() + layer_grad = layer_params_split['l0_' + k].grad() cell_grad = cell_params[k].grad() print('checking gradient for {}'.format('lstm0_l0_' + k)) assert_almost_equal(layer_grad, cell_grad, rtol=rtol, atol=atol) @@ -196,6 +204,39 @@ def forward(self, inpt): weights['{}0_h2h_bias'.format( d)] = mx.np.random.uniform(size=(size * 4,)) + if proj_size: + params = (weights['{}0_{}_{}'.format(d, g, t)].reshape(-1) + for t in ['weight', 'bias'] + for d in ['l', 'r'] + for g in ['i2h', 'h2h', 'h2r'] + if g != 'h2r' or t != 'bias') + else: + params = (weights['{}0_{}_{}'.format(d, g, t)].reshape(-1) + for t in ['weight', 'bias'] + for d in ['l', 'r'] + for g in ['i2h', 'h2h']) + + net_params_concat = mx.np.concatenate(params) + if proj_size: + params_left = (weights['l0_{}_{}'.format(g, t)].reshape(-1) + for t in ['weight', 'bias'] + for g in ['i2h', 'h2h', 'h2r'] + if g != 'h2r' or t != 'bias') + else: + params_left = (weights['l0_{}_{}'.format(g, t)].reshape(-1) + for t in ['weight', 'bias'] + for g in ['i2h', 'h2h']) + if proj_size: + params_right = (weights['r0_{}_{}'.format(g, t)].reshape(-1) + for t in ['weight', 'bias'] + for g in ['i2h', 'h2h', 'h2r'] + if g != 'h2r' or t != 'bias') + else: + params_right = (weights['r0_{}_{}'.format(g, t)].reshape(-1) + for t in ['weight', 'bias'] + for g in ['i2h', 'h2h']) + net_ref_left_params = mx.np.concatenate(params_left) + net_ref_right_params = mx.np.concatenate(params_right) net = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=True) ref_net = RefBiLSTM(size, proj_size) @@ -203,10 +244,9 @@ def forward(self, inpt): ref_net.initialize() net_params = net.collect_params() ref_net_params = ref_net.collect_params() - for k in weights: - net_params[k].set_data(weights[k]) - ref_net_params[k.replace('l0', '_lstm_fwd.l0').replace( - 'r0', '_lstm_bwd.l0')].set_data(weights[k]) + net_params['rnn_param'].set_data(net_params_concat) + ref_net_params['_lstm_fwd.rnn_param'].set_data(net_ref_left_params) + ref_net_params['_lstm_bwd.rnn_param'].set_data(net_ref_right_params) data = mx.np.random.uniform(size=(11, 10, in_size)) mx.test_utils.assert_allclose(net(data), ref_net(data), rtol=1e-6) @@ -214,12 +254,7 @@ def forward(self, inpt): def check_layer_bidirectional_varseqlen(size, in_size): - weights = {} - for d in ['l', 'r']: - weights['{}0_i2h_weight'.format(d)] = mx.np.random.uniform(size=(size*4, 
in_size)) - weights['{}0_h2h_weight'.format(d)] = mx.np.random.uniform(size=(size*4, size)) - weights['{}0_i2h_bias'.format(d)] = mx.np.random.uniform(size=(size*4,)) - weights['{}0_h2h_bias'.format(d)] = mx.np.random.uniform(size=(size*4,)) + weight = mx.np.random.uniform(size=(784,)) net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True) ref_net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=False) @@ -227,9 +262,8 @@ def check_layer_bidirectional_varseqlen(size, in_size): ref_net.initialize() net_params = net.collect_params() ref_net_params = ref_net.collect_params() - for k in weights: - net_params[k].set_data(weights[k]) - ref_net_params[k].set_data(weights[k]) + net_params['rnn_param'].set_data(weight) + ref_net_params['rnn_param'].set_data(weight) batch_size = 10 num_timesteps = 11 @@ -269,11 +303,10 @@ def check_layer_bidirectional_varseqlen(size, in_size): ref_net_params = ref_net.collect_params() - for k in weights: - net_grad = net_params[k].grad() - ref_net_grad = ref_net_params[k].grad() - assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(), - rtol=1e-2, atol=1e-6) + net_grad = net_params['rnn_param'].grad() + ref_net_grad = ref_net_params['rnn_param'].grad() + assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(), + rtol=1e-2, atol=1e-6) @assert_raises_cudnn_not_satisfied(min_version='5.1.10') diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 73c2359a560f..25dd4d5fdc64 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -164,13 +164,13 @@ def test_lstmp(): @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_lstm_cpu_inference(): # should behave the same as lstm cell - EXPECTED_LSTM_OUTPUT = np.array([[[0.7564661, 0.7564661, 0.96265966, 0.96265966], - [0.7564661, 0.7564661, 0.96265966, 0.96265966]], - [[0.96265966, 0.96265966, 0.7564661, 0.7564661], - [0.96265966, 0.96265966, 0.7564661, 0.7564661]]]) + EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213], + [0.72045636, 0.72045636, 0.95215213, 0.95215213]], + [[0.95215213, 0.95215213, 0.72045636, 0.72045636], + [0.95215213, 0.95215213, 0.72045636, 0.72045636]]]) x = mx.np.ones(shape=(2, 2, 2)) - model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True, param_initializer=mx.init.One()) - model.initialize() + model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True) + model.initialize(mx.init.One()) y = model(x).asnumpy() mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT, From a77035228b4043ed490b9ae39ac4828d0a95e851 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 1 Jul 2021 11:11:37 -0700 Subject: [PATCH 10/19] fix --- python/mxnet/initializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index b317e9e18f9f..20251dda8e51 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -819,7 +819,7 @@ def _init_weight(self, name, arr): def _init_util(self, param, connect, arr): name = "_{}_{}_initializer".format(connect, param) init = getattr(self, name) - create(init)(InitDesc(name, {'__init__': init}), arr) + create(init)(InitDesc(name, {'__init__': init}), arr) def set_initializer(self, init): self._i2h_weight_initializer = \ From 481bd94c1cab416e12dd922e73be9b4ee3be3ccc Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 1 Jul 2021 16:02:59 -0700 Subject: [PATCH 11/19] fix --- python/mxnet/initializer.py | 6 ++---- 
tests/python/gpu/test_gluon_gpu.py | 4 ++-- tests/python/unittest/test_gluon_rnn.py | 12 +++++++++++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index 20251dda8e51..987b357f5756 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -736,14 +736,13 @@ class RNNFused(Initializer): Default scale is 0.07. """ def __init__(self, mode, num_layers, state_size, bidirectional=False, - projection_size=None, scale=0.07, i2h_weight_initializer=None, + projection_size=None, i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer=None, h2h_bias_initializer=None, h2r_weight_initializer=None): super(RNNFused, self).__init__(mode=mode, num_layers=num_layers, state_size=state_size, bidirectional=bidirectional, projection_size=projection_size, - scale=scale, i2h_weight_initializer=i2h_weight_initializer, h2h_weight_initializer=h2h_weight_initializer, i2h_bias_initializer=i2h_bias_initializer, @@ -754,7 +753,6 @@ def __init__(self, mode, num_layers, state_size, bidirectional=False, self.num_hidden = state_size self.dir = 2 if bidirectional else 1 self.projection_size = projection_size - self.scale = scale self._i2h_weight_initializer = i2h_weight_initializer self._h2h_weight_initializer = h2h_weight_initializer self._i2h_bias_initializer = i2h_bias_initializer @@ -805,7 +803,7 @@ def _init_weight(self, name, arr): else: num_inputs = input_size if layer_num != 0: - num_inputs = self.projection_size * dir + num_inputs = self.projection_size * self.dir if connect == 'h2h': num_inputs = self.projection_size shape0 = self.gates * self.num_hidden diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 6abdb8461de6..2b3be1f4f9b8 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -114,10 +114,10 @@ def test_lstmp_gpu(): assert_almost_equal(layer_output, cell_output, rtol=rtol, atol=atol) layer_output.backward() cell_output.backward() - layer_params_split = split_rnn_params(layer_params['rnn_param'].data().copy(),\ + layer_params_split = split_rnn_params(layer_params['rnn_param'].grad().copy(),\ 'lstm', 1, input_size, hidden_size, False, projection_size=projection_size) for k, v in weights.items(): - layer_grad = layer_params_split['l0_' + k].grad() + layer_grad = layer_params_split['l0_' + k] cell_grad = cell_params[k].grad() print('checking gradient for {}'.format('lstm0_l0_' + k)) assert_almost_equal(layer_grad, cell_grad, rtol=rtol, atol=atol) diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 25dd4d5fdc64..31ac01cabf67 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -691,8 +691,18 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, inpu fused_weight_shape = fused_layer_params['rnn_param'].shape w = mx.np.zeros(shape=fused_weight_shape) - fused_layer_params['rnn_param'].set_data(w.copy()) fused_layer_params_split = split_rnn_params(w.copy(), mode, num_layers, input_size, hidden_size, bidirectional) + for name, value in fused_layer_params_split.items(): + if 'bias' in name: + fused_layer_params_split[name] = mx.np.random.normal(size=value.shape) + _dir = 2 if bidirectional else 1 + params = (fused_layer_params_split['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1) + for t in ['weight', 'bias'] + for l in range(num_layers) + for d in ['l', 'r'][:_dir] + for g in ['i2h', 'h2h']) + fused_params = 
mx.np.concatenate(params) + fused_layer_params['rnn_param'].set_data(fused_params.copy()) for name, value in fused_layer_params_split.items(): cur = name.split('_')[0] num = cur[1:] From e7a959a3ae33dd489ca5113c5196ba0ac9c7556c Mon Sep 17 00:00:00 2001 From: barry-jin Date: Fri, 2 Jul 2021 11:10:04 -0700 Subject: [PATCH 12/19] fix leak --- tests/python/unittest/test_gluon_rnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 31ac01cabf67..f15d935a2f19 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -691,7 +691,7 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, inpu fused_weight_shape = fused_layer_params['rnn_param'].shape w = mx.np.zeros(shape=fused_weight_shape) - fused_layer_params_split = split_rnn_params(w.copy(), mode, num_layers, input_size, hidden_size, bidirectional) + fused_layer_params_split = split_rnn_params(w, mode, num_layers, input_size, hidden_size, bidirectional) for name, value in fused_layer_params_split.items(): if 'bias' in name: fused_layer_params_split[name] = mx.np.random.normal(size=value.shape) @@ -702,7 +702,7 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, mode, num_layers, inpu for d in ['l', 'r'][:_dir] for g in ['i2h', 'h2h']) fused_params = mx.np.concatenate(params) - fused_layer_params['rnn_param'].set_data(fused_params.copy()) + fused_layer_params['rnn_param'].set_data(fused_params) for name, value in fused_layer_params_split.items(): cur = name.split('_')[0] num = cur[1:] From 5a498990a7d8390c6c50a9dd9a435b46619890c6 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 6 Jul 2021 09:44:03 -0700 Subject: [PATCH 13/19] fix --- tests/python/gpu/test_gluon_gpu.py | 6 +++--- tests/python/unittest/test_gluon_rnn.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 2b3be1f4f9b8..bc1da6417703 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -103,9 +103,9 @@ def test_lstmp_gpu(): if g != 'h2r' or t != 'bias') net_params_concat = mx.np.concatenate(params) - layer_params['rnn_param'].set_data(net_params_concat.copy()) + layer_params['rnn_param'].set_data(net_params_concat) for k, v in weights.items(): - cell_params[k].set_data(v.copy()) + cell_params[k].set_data(v) with autograd.record(): layer_output = lstm_layer(lstm_input.copy()) cell_output = lstm_cell.unroll(seq_len, lstm_input.copy(), layout='TNC', @@ -114,7 +114,7 @@ def test_lstmp_gpu(): assert_almost_equal(layer_output, cell_output, rtol=rtol, atol=atol) layer_output.backward() cell_output.backward() - layer_params_split = split_rnn_params(layer_params['rnn_param'].grad().copy(),\ + layer_params_split = split_rnn_params(layer_params['rnn_param'].grad(),\ 'lstm', 1, input_size, hidden_size, False, projection_size=projection_size) for k, v in weights.items(): layer_grad = layer_params_split['l0_' + k] diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index f15d935a2f19..822475f62413 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -103,8 +103,8 @@ def test_lstmp(): fused_weight_shape = fused_layer_params['rnn_param'].shape print(fused_weight_shape) w = mx.np.zeros(shape=fused_weight_shape) - fused_layer_params['rnn_param'].set_data(w.copy()) - 
fused_layer_params_split = split_rnn_params(w.copy(), 'lstm', num_layers, input_size,\ + fused_layer_params['rnn_param'].set_data(w) + fused_layer_params_split = split_rnn_params(w, 'lstm', num_layers, input_size,\ hidden_size, False, projection_size=projection_size) for name, value in fused_layer_params_split.items(): @@ -143,8 +143,8 @@ def test_lstmp(): fused_weight_shape = fused_layer_params['rnn_param'].shape w = mx.np.zeros(shape=fused_weight_shape) - fused_layer_params['rnn_param'].set_data(w.copy()) - fused_layer_params_split = split_rnn_params(w.copy(), 'lstm', num_layers, input_size,\ + fused_layer_params['rnn_param'].set_data(w) + fused_layer_params_split = split_rnn_params(w, 'lstm', num_layers, input_size,\ hidden_size, True, projection_size=projection_size) for name, value in fused_layer_params_split.items(): From 9ce8ae17ff9db34a497bbb342408903fabeef27c Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 6 Jul 2021 12:19:53 -0700 Subject: [PATCH 14/19] fix --- tests/python/unittest/test_numpy_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1abbeba4fcf2..053fa9d64daa 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3458,9 +3458,9 @@ def forward(self, a): @use_np def test_npx_activation_log_sigmoid(): def np_log_sigmoid(x): - return _np.log(_np.divide(1.0, (1.0 + _np.exp(-x)))) + return onp.log(onp.divide(1.0, (1.0 + onp.exp(-x)))) def np_log_sigmoid_grad(x): - return _np.divide(1.0, _np.add(1.0, _np.exp(x))) + return onp.divide(1.0, onp.add(1.0, onp.exp(x))) class TestLogSigmoid(HybridBlock): def __init__(self): From fb7cc5d61bd37a28b012da97429cc99313e26808 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Wed, 7 Jul 2021 12:43:09 -0700 Subject: [PATCH 15/19] fix --- tests/python/unittest/test_gluon_rnn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 822475f62413..a7dc7811ed9c 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -167,10 +167,10 @@ def test_lstm_cpu_inference(): EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213], [0.72045636, 0.72045636, 0.95215213, 0.95215213]], [[0.95215213, 0.95215213, 0.72045636, 0.72045636], - [0.95215213, 0.95215213, 0.72045636, 0.72045636]]]) - x = mx.np.ones(shape=(2, 2, 2)) + [0.95215213, 0.95215213, 0.72045636, 0.72045636]]], ctx=mx.cpu(0)) + x = mx.np.ones(shape=(2, 2, 2), ctx=mx.cpu(0)) model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True) - model.initialize(mx.init.One()) + model.initialize(mx.init.One(), ctx=mx.cpu(0)) y = model(x).asnumpy() mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT, From 7eb3ef5ac87144491802d017121db349121cc1bc Mon Sep 17 00:00:00 2001 From: barry-jin Date: Wed, 7 Jul 2021 16:24:39 -0700 Subject: [PATCH 16/19] update --- tests/python/gpu/test_gluon_gpu.py | 2 +- tests/python/unittest/test_gluon_rnn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index bc1da6417703..375801ae8dbe 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -74,7 +74,7 @@ def check_rnn_layer_w_rand_inputs(layer): @mx.util.use_np @assert_raises_cudnn_not_satisfied(min_version='7.2.1') -def test_lstmp_gpu(): +def test_lstmp(): 
hidden_size, projection_size = 3, 2 rtol, atol = 1e-2, 1e-2 batch_size, seq_len = 7, 11 diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index a7dc7811ed9c..b79d16e7dd01 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -1004,7 +1004,7 @@ def test_conv_fill_shape(): @mx.util.use_np -def test_lstmp_cell(): +def test_lstmp(): nhid = 100 nproj = 64 cell = gluon.rnn.LSTMPCell(nhid, nproj) From 24527ca871342d456465a8ec99bbe401f0252b12 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 8 Jul 2021 10:04:06 -0700 Subject: [PATCH 17/19] update centos cu102 to use cudnn8 --- ci/docker/docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/docker/docker-compose.yml b/ci/docker/docker-compose.yml index c9e410efcccb..0a3f320f42b9 100644 --- a/ci/docker/docker-compose.yml +++ b/ci/docker/docker-compose.yml @@ -58,7 +58,7 @@ services: dockerfile: Dockerfile.build.centos7 target: base args: - BASE_IMAGE: nvidia/cuda:10.2-cudnn7-devel-centos7 + BASE_IMAGE: nvidia/cuda:10.2-cudnn8-devel-centos7 cache_from: - ${DOCKER_CACHE_REGISTRY}/build.centos7_gpu_cu102:latest centos7_gpu_cu110: From 476fe1efbcae61e625751ba1c78dfdf2f1ac657b Mon Sep 17 00:00:00 2001 From: barry-jin Date: Fri, 9 Jul 2021 12:40:59 -0700 Subject: [PATCH 18/19] fix --- tests/python/unittest/test_gluon_rnn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index b79d16e7dd01..31748bad75d7 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -167,10 +167,10 @@ def test_lstm_cpu_inference(): EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213], [0.72045636, 0.72045636, 0.95215213, 0.95215213]], [[0.95215213, 0.95215213, 0.72045636, 0.72045636], - [0.95215213, 0.95215213, 0.72045636, 0.72045636]]], ctx=mx.cpu(0)) - x = mx.np.ones(shape=(2, 2, 2), ctx=mx.cpu(0)) + [0.95215213, 0.95215213, 0.72045636, 0.72045636]]]) + x = mx.np.ones(shape=(2, 2, 2)) model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True) - model.initialize(mx.init.One(), ctx=mx.cpu(0)) + model.initialize(mx.init.One()) y = model(x).asnumpy() mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT, From 340e7741839830922a3e9628e28d29a724d70e85 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Tue, 19 Oct 2021 16:40:48 -0700 Subject: [PATCH 19/19] fix conflict --- tests/python/gpu/test_gluon_gpu.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 9d3c79814f5f..94190abdb21e 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -114,15 +114,10 @@ def test_lstmp(): assert_almost_equal(layer_output, cell_output, rtol=rtol, atol=atol) layer_output.backward() cell_output.backward() -<<<<<<< HEAD layer_params_split = split_rnn_params(layer_params['rnn_param'].grad(),\ 'lstm', 1, input_size, hidden_size, False, projection_size=projection_size) - for k, v in weights.items(): - layer_grad = layer_params_split['l0_' + k] -======= for k, _ in weights.items(): - layer_grad = layer_params['l0_' + k].grad() ->>>>>>> upstream/master + layer_grad = layer_params_split['l0_' + k] cell_grad = cell_params[k].grad() print('checking gradient for {}'.format('lstm0_l0_' + k)) assert_almost_equal(layer_grad, cell_grad, rtol=rtol, atol=atol)
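
Note on the fused parameter layout exercised by the tests above: the series collapses the per-gate Gluon parameters (l0_i2h_weight, l0_h2h_bias, ...) into one flat 'rnn_param' vector, and check_rnn_consistency / test_lstmp rebuild that vector by concatenating the split views weights first, then biases, iterating layers, directions ('l' then 'r'), and connections ('i2h' then 'h2h') before calling set_data. The standalone NumPy sketch below mirrors that packing for an LSTM without projection so the expected length and slicing can be checked by hand. It is illustrative only: fused_lstm_param_size and split_fused_lstm_params are names invented for this sketch, not the split_rnn_params helper the tests import, and the real helper additionally covers projection_size and the rnn_relu/rnn_tanh/gru modes.

    # Illustrative sketch of the flattened LSTM parameter layout assumed by the
    # tests (weights for every layer/direction first, then the matching biases).
    # Not the MXNet test helper; no projection handling.
    import numpy as np

    def fused_lstm_param_size(input_size, hidden_size, num_layers=1, bidirectional=False):
        """Total scalars in the flat parameter vector under the assumed layout."""
        gates, dirs = 4, 2 if bidirectional else 1
        size, ni = 0, input_size
        for _ in range(num_layers):
            for _ in range(dirs):
                size += gates * hidden_size * ni           # i2h weight
                size += gates * hidden_size * hidden_size  # h2h weight
                size += 2 * gates * hidden_size            # i2h bias + h2h bias
            ni = hidden_size * dirs  # next layer consumes the concatenated outputs
        return size

    def split_fused_lstm_params(flat, input_size, hidden_size, num_layers=1, bidirectional=False):
        """Slice the flat vector back into {d}{l}_{i2h,h2h}_{weight,bias} views,
        following the same weights-then-biases ordering as the packing above."""
        gates, dirs = 4, 2 if bidirectional else 1
        out, offset = {}, 0

        def take(n, shape):
            nonlocal offset
            chunk = flat[offset:offset + n].reshape(shape)
            offset += n
            return chunk

        for t in ('weight', 'bias'):
            ni = input_size
            for l in range(num_layers):
                for d in ('l', 'r')[:dirs]:
                    if t == 'weight':
                        out['{}{}_i2h_weight'.format(d, l)] = take(
                            gates * hidden_size * ni, (gates * hidden_size, ni))
                        out['{}{}_h2h_weight'.format(d, l)] = take(
                            gates * hidden_size * hidden_size, (gates * hidden_size, hidden_size))
                    else:
                        out['{}{}_i2h_bias'.format(d, l)] = take(
                            gates * hidden_size, (gates * hidden_size,))
                        out['{}{}_h2h_bias'.format(d, l)] = take(
                            gates * hidden_size, (gates * hidden_size,))
                ni = hidden_size * dirs
        assert offset == flat.size, 'layout does not cover the whole vector'
        return out

    if __name__ == '__main__':
        n = fused_lstm_param_size(input_size=5, hidden_size=3, num_layers=2, bidirectional=True)
        parts = split_fused_lstm_params(np.arange(n, dtype=np.float32), 5, 3, 2, True)
        print(n, sorted(parts))

Running the sketch prints the total length (504 for the example configuration) and the per-view keys; if the layout assumption holds, that total should match fused_layer_params['rnn_param'].shape for the same configuration, which makes it a quick way to sanity-check the ordering relied on by the updated gradient and consistency tests.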