RNN operator produces inconsistent gradients for h2h_bias for stacked RNNs #17818
Description
The RNN operator produces inconsistent gradient values for h2h_bias in the topmost layer of the stack across different MXNet variants.
This only happens when num_layers > 1. I compared the values after one forward pass and one backward pass, between osx-cpu-mkl and linux-gpu-mkl (both built from source; the snapshot versions are not up to date, and the ones available are very flaky in this case). In the CPU variant, the gradient for the topmost h2h_bias is all zeros.
These differences lead to exceptions in DJL in the CPU context. I also confirmed that the same differences exist in Python.
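To make a cross-build comparison like the one above repeatable, one option is to dump each build's gradients to a file (for example with `numpy.savez`, after converting each gradient with `.asnumpy()`) and diff the two dumps offline. The sketch below assumes two `{param_name: gradient}` dicts, one per build. The helper `compare_grads` and its tolerances are hypothetical and not part of MXNet. It flags shape mismatches, all-zero gradients (the CPU symptom described here) and value differences. The arrays in the demo are synthetic and are not results from this issue.

```python
import numpy as np


def compare_grads(grads_a, grads_b, rtol=1e-4, atol=1e-6):
    """Compare two {param_name: gradient} dicts produced by different builds.

    Returns {param_name: status} for every parameter present in both dicts.
    """
    report = {}
    for name in sorted(set(grads_a) & set(grads_b)):
        a = np.asarray(grads_a[name], dtype=np.float64)
        b = np.asarray(grads_b[name], dtype=np.float64)
        if a.shape != b.shape:
            report[name] = "shape mismatch"
        elif np.allclose(a, b, rtol=rtol, atol=atol):
            report[name] = "match"
        elif not np.any(a):
            # Build A returned an all-zero gradient while build B did not.
            report[name] = "all zeros in build A"
        elif not np.any(b):
            report[name] = "all zeros in build B"
        else:
            report[name] = "max abs diff %.3g" % np.max(np.abs(a - b))
    return report


if __name__ == "__main__":
    # Synthetic stand-ins for the dumps of two builds (hypothetical values).
    build_a = {"lstm0_l1_h2h_bias": np.zeros(256), "lstm0_l1_i2h_bias": np.full(256, 0.5)}
    build_b = {"lstm0_l1_h2h_bias": np.full(256, 0.01), "lstm0_l1_i2h_bias": np.full(256, 0.5)}
    for name, status in compare_grads(build_a, build_b).items():
        print(name, "->", status)
```

With real dumps, `dict(np.load(path))` turns a saved `.npz` file back into such a dict.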
Error Message
I don't see any error in Python, but in DJL we have observed NaN values during training.
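To pin down where the NaN values first appear, one approach is to scan every gradient after each backward pass and stop at the first step that produces a non-finite value. The framework-agnostic NumPy sketch below assumes the gradients have already been collected into a dict, for example `{name: p.grad().asnumpy() for name, p in net.collect_params().items() if p.grad_req != 'null'}` in the Python script. The helper name `nonfinite_params` is hypothetical.

```python
import numpy as np


def nonfinite_params(grads):
    """Return {param_name: number of NaN/Inf entries} for gradients that are not finite."""
    bad = {}
    for name, g in grads.items():
        g = np.asarray(g)
        # Count entries that are NaN, +Inf or -Inf.
        n_bad = int(g.size - np.count_nonzero(np.isfinite(g)))
        if n_bad:
            bad[name] = n_bad
    return bad


if __name__ == "__main__":
    # Synthetic example: one gradient with a NaN and an Inf, one clean gradient.
    grads = {"lstm0_l1_h2h_bias": np.array([0.1, np.nan, np.inf]), "lstm0_l0_h2h_bias": np.zeros(2)}
    print(nonfinite_params(grads))
```

Calling this after every `backward()` and breaking on the first non-empty result narrows the problem down to a specific iteration and parameter.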
To Reproduce
I used the following script to inspect the gradients.
```python
import mxnet as mx
from mxnet import gpu, cpu, gluon
from mxnet import np, npx
from mxnet import autograd as ag
from mxnet.gluon import nn

npx.set_np()


def check_param(net):
    param_dict = net.collect_params()
    # Gluon names LSTM parameters '<prefix>{direction}{layer}_{gate}_{type}';
    # 'lstm0_' is the prefix of the first LSTM block created in the process, so
    # 'lstm0_l1_h2h_bias' is the h2h bias of the topmost layer when num_layers == 2.
    array = ('lstm0_{}{}_{}_{}'.format(d, l, g, t)
             for t in ['weight', 'bias']
             for l in range(num_layers)
             for d in ['l', 'r'][:1]  # unidirectional: left-to-right direction only
             for g in ['i2h', 'h2h'])
    for key in array:
        param = param_dict[key]
        print("checking param: " + str(param))
        print("weight sum: " + str(param.data().sum()))
        print("weight mean: " + str(param.data().mean()))
        print("weight max: " + str(param.data().max()))
        print("weight min: " + str(param.data().min()))
        if param.grad_req != "null":
            print("checking the gradient of param: " + str(param))
            print("grad sum: " + str(param.grad().sum()))
            print("grad mean: " + str(param.grad().mean()))
            print("grad max: " + str(param.grad().max()))
            print("grad min: " + str(param.grad().min()))


def print_ndarray_stats(ndarray, name):
    print("#####", name, "#####")
    print("checking " + name)
    print("sum: " + str(ndarray.sum()))
    print("mean: " + str(ndarray.mean()))
    print("max: " + str(ndarray.max()))
    print("min: " + str(ndarray.min()))
    print("Shape: " + str(ndarray.shape))


batch = 32
time = 28
channel = 28
state = 64
num_layers = 2  # the issue only shows up when num_layers > 1

mx.random.seed(1234)
data = np.random.uniform(0, 10, size=(batch, time, channel))
mx.random.seed(1234)
labels = np.random.uniform(0, 1, size=(batch, time, state))

net = gluon.rnn.LSTM(state, num_layers=num_layers,
                     h2h_weight_initializer=mx.initializer.Xavier(),
                     i2h_weight_initializer=mx.initializer.Xavier(),
                     layout='NTC')
loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False, from_logits=True)
net.collect_params().initialize()

# One forward pass and one backward pass, then dump weight and gradient statistics.
with ag.record():
    z = net(data)
    L = loss(z, labels).mean()
print_ndarray_stats(z, "OUTPUT")
print("Loss = ", L)
L.backward()
check_param(net)
```
Steps to reproduce
- Install the MXNet build under test (e.g. osx-cpu-mkl or linux-gpu-mkl, built from source)
- Run the above script (python rnntest.py)