
Commit 5f0efbb

[v2.0] RNN: use rnn_params (#20384)
* use rnn_params
* add split rnn parameter in gluon.utils
* update
* update
* use zero weight
* add rnn fused parameter initializer
* fix lint
* fix tests
* update RNNFused initializer
* fix
* fix
* fix leak
* fix
* fix
* fix
* update
* update centos cu102 to use cudnn8
* fix
* fix conflict
1 parent 481eba7 commit 5f0efbb

8 files changed: 339 additions and 122 deletions
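At a glance, the commit replaces the per-layer `l0_i2h_weight` / `l0_h2h_weight` / bias parameters of the Gluon RNN layers with a single fused `rnn_param` vector laid out the way the cuDNN/`npx.rnn` kernel expects. A minimal sketch of the user-visible effect, assuming the MXNet 2.0 numpy-mode Gluon API (the sizes are illustrative):

    import mxnet as mx
    mx.npx.set_np()   # NumPy semantics, as used by the 2.0 Gluon blocks

    layer = mx.gluon.rnn.LSTM(hidden_size=20, num_layers=2)
    layer.initialize()
    out = layer(mx.np.random.uniform(size=(5, 3, 10)))  # TNC: (seq_len, batch, input_size)

    # Before this commit: separate l0_i2h_weight, l0_h2h_weight, l0_i2h_bias, ... parameters.
    # After: one flat vector holding every weight and bias back to back.
    print(layer.collect_params())   # a single 'rnn_param' entry
    print(layer.rnn_param.shape)    # (5920,) for this configuration (see infer_shape below)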

ci/docker/docker-compose.yml

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ services:
       dockerfile: Dockerfile.build.centos7
       target: base
       args:
-        BASE_IMAGE: nvidia/cuda:10.2-cudnn7-devel-centos7
+        BASE_IMAGE: nvidia/cuda:10.2-cudnn8-devel-centos7
       cache_from:
         - ${DOCKER_CACHE_REGISTRY}/build.centos7_gpu_cu102:latest
   centos7_gpu_cu110:

python/mxnet/gluon/parameter.py

Lines changed: 3 additions & 0 deletions
@@ -472,6 +472,9 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
             ctx = [context.current_context()]
         if isinstance(ctx, Context):
             ctx = [ctx]
+        if isinstance(self.init, initializer.RNNFused):
+            self.init.set_initializer(init if init else default_init)
+            init = default_init = self.init
         if init is None:
             init = default_init if self.init is None else self.init
         if not shape_is_known(self.shape):
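The new branch lets `Parameter.initialize()` cooperate with the fused layout: when a parameter's own `init` is an `RNNFused` instance, the user-supplied (or default) initializer is pushed into it, and the fused initializer then takes over the whole flat buffer. A hedged restatement of that dispatch with illustrative names (`self_init` stands for `Parameter.init`, `user_init` for the `init=` argument):

    from mxnet import initializer

    self_init = initializer.RNNFused('lstm', num_layers=2, state_size=20)
    user_init = initializer.Xavier()

    if isinstance(self_init, initializer.RNNFused):
        # Forward the global initializer to the parts that were not given an
        # explicit per-part initializer ...
        self_init.set_initializer(user_init)
        # ... and let RNNFused fill the entire fused vector.
        init = default_init = self_init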

python/mxnet/gluon/rnn/rnn_layer.py

Lines changed: 29 additions & 76 deletions
@@ -23,7 +23,7 @@

 __all__ = ['RNN', 'LSTM', 'GRU']

-from ... import np, npx, context
+from ... import np, npx, context, initializer
 from .. import HybridBlock, tensor_types
 from ..parameter import Parameter
 from ...util import use_np
@@ -50,11 +50,6 @@ def __init__(self, hidden_size, num_layers, layout,
         self._dropout = dropout
         self._dir = 2 if bidirectional else 1
         self._input_size = input_size
-        self._i2h_weight_initializer = i2h_weight_initializer
-        self._h2h_weight_initializer = h2h_weight_initializer
-        self._i2h_bias_initializer = i2h_bias_initializer
-        self._h2h_bias_initializer = h2h_bias_initializer
-        self._h2r_weight_initializer = h2r_weight_initializer
         self._lstm_state_clip_min = lstm_state_clip_min
         self._lstm_state_clip_max = lstm_state_clip_max
         self._lstm_state_clip_nan = lstm_state_clip_nan
@@ -64,48 +59,17 @@ def __init__(self, hidden_size, num_layers, layout,

         self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

-        ng, ni, nh = self._gates, input_size, hidden_size
-        if not projection_size:
-            for i in range(num_layers):
-                for j in ['l', 'r'][:self._dir]:
-                    self._register_param('{}{}_i2h_weight'.format(j, i),
-                                         shape=(ng*nh, ni),
-                                         init=i2h_weight_initializer, dtype=dtype)
-                    self._register_param('{}{}_h2h_weight'.format(j, i),
-                                         shape=(ng*nh, nh),
-                                         init=h2h_weight_initializer, dtype=dtype)
-                    self._register_param('{}{}_i2h_bias'.format(j, i),
-                                         shape=(ng*nh,),
-                                         init=i2h_bias_initializer, dtype=dtype)
-                    self._register_param('{}{}_h2h_bias'.format(j, i),
-                                         shape=(ng*nh,),
-                                         init=h2h_bias_initializer, dtype=dtype)
-                ni = nh * self._dir
-        else:
-            ps = self._projection_size
-            for i in range(num_layers):
-                for j in ['l', 'r'][:self._dir]:
-                    self._register_param('{}{}_i2h_weight'.format(j, i),
-                                         shape=(ng*nh, ni),
-                                         init=i2h_weight_initializer, dtype=dtype)
-                    self._register_param('{}{}_h2h_weight'.format(j, i),
-                                         shape=(ng*nh, ps),
-                                         init=h2h_weight_initializer, dtype=dtype)
-                    self._register_param('{}{}_i2h_bias'.format(j, i),
-                                         shape=(ng*nh,),
-                                         init=i2h_bias_initializer, dtype=dtype)
-                    self._register_param('{}{}_h2h_bias'.format(j, i),
-                                         shape=(ng*nh,),
-                                         init=h2h_bias_initializer, dtype=dtype)
-                    self._register_param('{}{}_h2r_weight'.format(j, i),
-                                         shape=(ps, nh),
-                                         init=h2r_weight_initializer, dtype=dtype)
-                ni = ps * self._dir
-
-    def _register_param(self, name, shape, init, dtype):
-        p = Parameter(name, shape=shape, init=init, allow_deferred_init=True, dtype=dtype)
-        setattr(self, name, p)
-        return p
+        param_initializer = initializer.RNNFused(
+            mode, num_layers, hidden_size,
+            bidirectional, projection_size,
+            i2h_weight_initializer=i2h_weight_initializer,
+            h2h_weight_initializer=h2h_weight_initializer,
+            i2h_bias_initializer=i2h_bias_initializer,
+            h2h_bias_initializer=h2h_bias_initializer,
+            h2r_weight_initializer=h2r_weight_initializer)
+
+        self.rnn_param = Parameter('rnn_param', shape=(-1,), init=param_initializer,
+                                   allow_deferred_init=True, dtype=dtype)

     def __repr__(self):
         s = '{name}({mapping}, {_layout}'
@@ -116,8 +80,7 @@ def __repr__(self):
         if self._dir == 2:
             s += ', bidirectional'
         s += ')'
-        shape = self.l0_i2h_weight.shape
-        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates)
+        mapping = '{0} -> {1}'.format(self._input_size if self._input_size else None, self._hidden_size)
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **self.__dict__)
@@ -196,37 +159,26 @@ def forward(self, inputs, states, sequence_length=None):
     def infer_shape(self, inputs, *args):
         assert inputs.ndim == 3, \
             "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]"
-        if not self._projection_size:
-            step = self._hidden_size
-        else:
-            step = self._projection_size
-        ni = inputs.shape[2]
-        for i in range(self._num_layers):
-            for j in ['l', 'r'][:self._dir]:
-                name = '{}{}_i2h_weight'.format(j, i)
-                getattr(self, name).shape = (self._gates*self._hidden_size, ni)
-            ni = step * self._dir
+        self._input_size = inputs.shape[2]
+        ng, ni, nh = self._gates, inputs.shape[2], self._hidden_size
+
+        size = nh * self._dir * ng
+        size1 = (ni + nh + 2) * size  # first layer size
+        size2 = (nh * self._dir + nh + 2) * size  # second layer size
+        if self._projection_size:
+            size1 = (ni + self._projection_size + 2) * size  # first layer size
+            size2 = (self._projection_size * self._dir + \
+                     self._projection_size + 2) * size  # second layer size
+        param_size = size1 + (self._num_layers - 1) * size2
+        if self._projection_size:
+            param_size += self._projection_size * nh * self._num_layers * self._dir
+        self.rnn_param.shape = (param_size, )

     def _forward_kernel(self, inputs, states, sequence_length):
         """ forward using CUDNN or CPU kenrel"""
         ctx = inputs.ctx
         if self._layout == 'NTC':
             inputs = np.swapaxes(inputs, 0, 1)
-        if self._projection_size is None:
-            params = (getattr(self, '{}{}_{}_{}'.format(d, l, g, t)).data(ctx).reshape(-1)
-                      for t in ['weight', 'bias']
-                      for l in range(self._num_layers)
-                      for d in ['l', 'r'][:self._dir]
-                      for g in ['i2h', 'h2h'])
-        else:
-            params = (getattr(self, '{}{}_{}_{}'.format(d, l, g, t)).data(ctx).reshape(-1)
-                      for t in ['weight', 'bias']
-                      for l in range(self._num_layers)
-                      for d in ['l', 'r'][:self._dir]
-                      for g in ['i2h', 'h2h', 'h2r']
-                      if g != 'h2r' or t != 'bias')
-
-        params = np.concatenate(params, axis=0)

         if self._use_sequence_length:
             rnn_args = states + [sequence_length]
@@ -238,7 +190,8 @@ def _forward_kernel(self, inputs, states, sequence_length):
             new_args = args.as_in_ctx(ctx)
             rnn_args_ctx.append(new_args)

-        rnn = npx.rnn(inputs, params, *rnn_args_ctx, use_sequence_length=self._use_sequence_length,
+        rnn = npx.rnn(inputs, self.rnn_param.data().as_in_ctx(ctx), *rnn_args_ctx,
+                      use_sequence_length=self._use_sequence_length,
                       state_size=self._hidden_size, projection_size=self._projection_size,
                       num_layers=self._num_layers, bidirectional=self._dir == 2,
                       p=self._dropout, state_outputs=True, mode=self._mode,
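The fused size computed in `infer_shape` can be checked by hand. A short worked example for a hypothetical 2-layer unidirectional LSTM with `input_size=10`, `hidden_size=20` and no projection, using the same formula as the diff (the `+ 2` term accounts for the i2h and h2h bias vectors):

    ng, ni, nh, ndir, num_layers = 4, 10, 20, 1, 2   # LSTM has 4 gates

    size = nh * ndir * ng                     # 80
    size1 = (ni + nh + 2) * size              # layer 0: (10 + 20 + 2) * 80 = 2560
    size2 = (nh * ndir + nh + 2) * size       # layers 1..n-1: (20 + 20 + 2) * 80 = 3360
    param_size = size1 + (num_layers - 1) * size2   # 5920

    # Cross-check against the unfused shapes the old code registered:
    # l0: i2h (80, 10), h2h (80, 20), two biases of length 80
    # l1: i2h (80, 20), h2h (80, 20), two biases of length 80
    expected = (80*10 + 80*20 + 80 + 80) + (80*20 + 80*20 + 80 + 80)
    assert param_size == expected == 5920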

python/mxnet/gluon/utils.py

Lines changed: 76 additions & 0 deletions
@@ -504,3 +504,79 @@ def _check_block_input_np_ndarrays(inputs):
     for i in inputs:
         _check_block_input_np_ndarrays(i)
     # pylint: enable=no-else-raise
+
+
+# pylint: disable=too-many-nested-blocks
+def split_rnn_params(param, mode, num_layers, input_size, hidden_size, bidirectional=False, projection_size=None):
+    """Split rnn layer parameter into weight and bias in different layer.
+
+    Parameters
+    ----------
+    param : ndarray
+        The parameter of rnn layer.
+    mode : str
+        Mode of rnn. Supported modes: rnn_relu, rnn_tanh, lstm, gru
+    num_layers : int, default 1
+        Number of recurrent layers.
+    input_size: int, default 0
+        The number of expected features in the input x.
+        If not specified, it will be inferred from input.
+    hidden_size: int
+        The number of features in the hidden state h.
+    bidirectional: bool, default False
+        If `True`, becomes a bidirectional RNN.
+    projection_size: int, default None
+        The number of features after projection.
+    """
+    gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
+    dir = 2 if bidirectional else 1
+    param_dict = {}
+    begin = 0
+    if not projection_size:
+        for p in ['weight', 'bias']:
+            for l in range(num_layers):
+                for d in ['l', 'r'][:dir]:
+                    for g in ['i2h', 'h2h']:
+                        ni = input_size
+                        if l != 0:
+                            ni = hidden_size * dir
+                        if g == 'h2h':
+                            ni = hidden_size
+                        shape0 = gates * hidden_size
+                        if p == 'weight':
+                            cur_len = shape0 * ni
+                            param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
+                                param[begin:begin+cur_len].reshape(shape0, ni)
+                        else:
+                            cur_len = shape0
+                            param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
+                                param[begin:begin+cur_len].reshape(shape0,)
+                        begin += cur_len
+    else:
+        for p in ['weight', 'bias']:
+            for l in range(num_layers):
+                for d in ['l', 'r'][:dir]:
+                    for g in ['i2h', 'h2h', 'h2r']:
+                        if g != 'h2r' or p != 'bias':
+                            if g == 'h2r':
+                                cur_len = projection_size * hidden_size
+                                param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
+                                    param[begin:begin+cur_len]. \
+                                    reshape(projection_size, hidden_size)
+                            else:
+                                ni = input_size
+                                if l != 0:
+                                    ni = projection_size * dir
+                                if g == 'h2h':
+                                    ni = projection_size
+                                shape0 = gates * hidden_size
+                                if p == 'weight':
+                                    cur_len = shape0 * ni
+                                    param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
+                                        param[begin:begin+cur_len].reshape(shape0, ni)
+                                else:
+                                    cur_len = shape0
+                                    param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
+                                        param[begin:begin+cur_len].reshape(shape0,)
+                            begin += cur_len
+    return param_dict
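A hedged usage sketch for the new helper: once the first forward pass has fixed the fused shape, `split_rnn_params` recovers the familiar per-layer views from the flat vector. It assumes the MXNet 2.0 numpy-mode Gluon API; the concrete sizes are illustrative:

    import mxnet as mx
    from mxnet.gluon.utils import split_rnn_params
    mx.npx.set_np()

    layer = mx.gluon.rnn.LSTM(hidden_size=20, num_layers=2)
    layer.initialize()
    out = layer(mx.np.random.uniform(size=(5, 3, 10)))   # TNC: (seq_len, batch, input_size)

    parts = split_rnn_params(layer.rnn_param.data(), 'lstm',
                             num_layers=2, input_size=10, hidden_size=20)
    print(parts['l0_i2h_weight'].shape)   # (80, 10): 4 gates * hidden_size rows
    print(parts['l1_h2h_bias'].shape)     # (80,)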

python/mxnet/initializer.py

Lines changed: 119 additions & 0 deletions
@@ -711,3 +711,122 @@ def _init_weight(self, name, arr):
         # gate of the 4 LSTM gates, we modify the according values.
         num_hidden = int(arr.shape[0] / 4)
         arr[num_hidden:2*num_hidden] = self.forget_bias
+
+
+@register
+class RNNFused(Initializer):
+    """Initialize RNN fused parameter with bias part initialized to 0.0 and
+    weight initialized with random values uniformly sampled from a given range.
+
+    Parameters
+    ----------
+    mode : {'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required
+        the type of RNN to compute
+    num_layers : int (non-negative), required
+        number of stacked layers
+    state_size : int (non-negative), required
+        size of the state for each layer
+    bidirectional : boolean, optional, default=0
+        whether to use bidirectional recurrent layers
+    projection_size : int or None, optional, default='None'
+        size of project size
+    scale : float, optional
+        The bound on the range of the generated random values for weights.
+        Values are generated from the range [-`scale`, `scale`].
+        Default scale is 0.07.
+    """
+    def __init__(self, mode, num_layers, state_size, bidirectional=False,
+                 projection_size=None, i2h_weight_initializer=None,
+                 h2h_weight_initializer=None, i2h_bias_initializer=None,
+                 h2h_bias_initializer=None, h2r_weight_initializer=None):
+        super(RNNFused, self).__init__(mode=mode, num_layers=num_layers,
+                                       state_size=state_size,
+                                       bidirectional=bidirectional,
+                                       projection_size=projection_size,
+                                       i2h_weight_initializer=i2h_weight_initializer,
+                                       h2h_weight_initializer=h2h_weight_initializer,
+                                       i2h_bias_initializer=i2h_bias_initializer,
+                                       h2h_bias_initializer=h2h_bias_initializer,
+                                       h2r_weight_initializer=h2r_weight_initializer)
+        self.gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
+        self.num_layers = num_layers
+        self.num_hidden = state_size
+        self.dir = 2 if bidirectional else 1
+        self.projection_size = projection_size
+        self._i2h_weight_initializer = i2h_weight_initializer
+        self._h2h_weight_initializer = h2h_weight_initializer
+        self._i2h_bias_initializer = i2h_bias_initializer
+        self._h2h_bias_initializer = h2h_bias_initializer
+        self._h2r_weight_initializer = h2r_weight_initializer
+
+    # pylint: disable=too-many-nested-blocks
+    def _init_weight(self, name, arr):
+        arr_len = arr.shape[0]
+        size = self.num_hidden * self.dir * self.gates
+        if not self.projection_size:
+            # second layer size
+            size2 = (self.num_hidden * self.dir + self.num_hidden + 2) * size
+            input_size = (arr_len - (self.num_layers - 1) * size2) // \
+                         size - 2 - self.num_hidden
+        else:
+            # second layer size
+            size2 = (self.projection_size * self.dir + self.projection_size + 2) * size
+            size_projection = self.projection_size * self.num_hidden * self.num_layers * self.dir
+            input_size = (arr_len - size_projection - (self.num_layers - 1) * size2) // \
+                         size - 2 - self.projection_size
+        begin = 0
+        if not self.projection_size:
+            for param in ['weight', 'bias']:
+                for layer_num in range(self.num_layers):
+                    for _ in range(self.dir):
+                        for connect in ['i2h', 'h2h']:
+                            num_inputs = input_size
+                            if layer_num != 0:
+                                num_inputs = self.num_hidden * self.dir
+                            if connect == 'h2h':
+                                num_inputs = self.num_hidden
+                            shape0 = self.gates * self.num_hidden
+                            if param == 'weight':
+                                cur_len = shape0 * num_inputs
+                            else:
+                                cur_len = shape0
+                            self._init_util(param, connect, arr[begin:begin+cur_len])
+                            begin += cur_len
+        else:
+            for param in ['weight', 'bias']:
+                for layer_num in range(self.num_layers):
+                    for _ in range(self.dir):
+                        for connect in ['i2h', 'h2h', 'h2r']:
+                            if connect != 'h2r' or param != 'bias':
+                                if connect == 'h2r':
+                                    cur_len = self.projection_size * self.num_hidden
+                                else:
+                                    num_inputs = input_size
+                                    if layer_num != 0:
+                                        num_inputs = self.projection_size * self.dir
+                                    if connect == 'h2h':
+                                        num_inputs = self.projection_size
+                                    shape0 = self.gates * self.num_hidden
+                                    if param == 'weight':
+                                        cur_len = shape0 * num_inputs
+                                    else:
+                                        cur_len = shape0
+                                self._init_util(param, connect, arr[begin:begin+cur_len])
+                                begin += cur_len
+
+    def _init_util(self, param, connect, arr):
+        name = "_{}_{}_initializer".format(connect, param)
+        init = getattr(self, name)
+        create(init)(InitDesc(name, {'__init__': init}), arr)
+
+    def set_initializer(self, init):
+        self._i2h_weight_initializer = \
+            init if not self._i2h_weight_initializer else 'uniform'
+        self._h2h_weight_initializer = \
+            init if not self._h2h_weight_initializer else 'uniform'
+        self._i2h_bias_initializer = \
+            init if not self._i2h_bias_initializer else 'zero'
+        self._h2h_bias_initializer = \
+            init if not self._i2h_bias_initializer else 'zero'
+        self._h2r_weight_initializer = \
+            init if not self._h2r_weight_initializer else 'uniform'
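A hedged end-to-end sketch of how `RNNFused` is driven: the Gluon RNN layer constructs it from the per-part initializer arguments (see `rnn_layer.py` above), and `Parameter.initialize()` forwards the global initializer through `set_initializer` before the fused vector is filled slice by slice in `_init_weight`. Assuming the MXNet 2.0 Gluon API, with illustrative sizes:

    import mxnet as mx
    mx.npx.set_np()

    layer = mx.gluon.rnn.LSTM(hidden_size=20, num_layers=2)   # registers an RNNFused init
    layer.initialize(mx.init.Xavier())   # forwarded into RNNFused.set_initializer
    out = layer(mx.np.random.uniform(size=(5, 3, 10)))

    # Weight slices of the fused vector are drawn by Xavier; bias slices use the
    # layer's default 'zeros' bias initializers.
    print(layer.rnn_param.data()[:5])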
