This repository was archived by the owner on Nov 17, 2023. It is now read-only.
[Activation] GELU precision mismatch between MXNet and PyTorch in the CPU version #18826
Closed
Description
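On CPU, mx.npx.leaky_relu(x, act_type='gelu') returns values that differ from PyTorch's torch.nn.functional.gelu by more than rtol=1E-4, atol=1E-4. The erf-based GELU formula written out with MXNet operators agrees with PyTorch within the same tolerance, so the discrepancy is specific to the built-in GELU operator.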
The minimal reproducible example:
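import math
import mxnet as mx
from numpy.testing import assert_allclose
mx.npx.set_np()
# Input lives on the default (CPU) context
a = mx.np.random.normal(0, 1, (10000,))
b = mx.npx.leaky_relu(a, act_type='gelu')  # built-in GELU operator
c = a * 0.5 * (1.0 + mx.npx.erf(a / math.sqrt(2.0)))  # erf-based GELU written out by hand
import torch
a_torch = torch.from_numpy(a.asnumpy()).cuda()  # PyTorch reference is computed on the GPU
b_torch = torch.nn.functional.gelu(a_torch)
assert_allclose(b_torch.cpu().numpy(), c.asnumpy(), 1E-4, 1E-4)  # passes
assert_allclose(b_torch.cpu().numpy(), b.asnumpy(), 1E-4, 1E-4)  # fails

The GPU version has no issue: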
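import math
import mxnet as mx
from numpy.testing import assert_allclose
mx.npx.set_np()
# Same test, with the input created on the GPU context
a = mx.np.random.normal(0, 1, (10000,), ctx=mx.gpu())
b = mx.npx.leaky_relu(a, act_type='gelu')  # built-in GELU operator
c = a * 0.5 * (1.0 + mx.npx.erf(a / math.sqrt(2.0)))  # erf-based GELU written out by hand
import torch
a_torch = torch.from_numpy(a.asnumpy()).cuda()
b_torch = torch.nn.functional.gelu(a_torch)
assert_allclose(b_torch.cpu().numpy(), c.asnumpy(), 1E-4, 1E-4)  # passes
assert_allclose(b_torch.cpu().numpy(), b.asnumpy(), 1E-4, 1E-4)  # passes

Error (from the CPU version):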
<ipython-input-48-6f3377797f65> in <module>
      9 b_torch = torch.nn.functional.gelu(a_torch)
     10 assert_allclose(b_torch.cpu().numpy(), c.asnumpy(), 1E-4, 1E-4)
---> 11 assert_allclose(b_torch.cpu().numpy(), b.asnumpy(), 1E-4, 1E-4)

~/.local/lib/python3.6/site-packages/numpy/testing/_private/utils.py in assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
   1526     header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
   1527     assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
-> 1528                          verbose=verbose, header=header, equal_nan=equal_nan)
   1529
   1530

~/.local/lib/python3.6/site-packages/numpy/testing/_private/utils.py in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf)
    838                                 verbose=verbose, header=header,
    839                                 names=('x', 'y'), precision=precision)
--> 840             raise AssertionError(msg)
    841         except ValueError:
    842             import traceback

AssertionError:
Not equal to tolerance rtol=0.0001, atol=0.0001
Mismatched elements: 2258 / 10000 (22.6%)
Max absolute difference: 0.0004735
Max relative difference: 0.8255573
x: array([ 0.684651, 0.508604, -0.165598, ..., 1.706593, 0.288036,
1.006167], dtype=float32)
y: array([ 0.68455 , 0.508554, -0.165716, ..., 1.706508, 0.288026,
1.005966], dtype=float32)
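
The report does not identify the cause of the mismatch. One hypothesis, which is an assumption and not something the report confirms, is that the CPU operator evaluates the common tanh approximation of GELU rather than the erf-based definition. The sketch below uses only the Python standard library to show how far apart those two formulas are. The helper names gelu_erf and gelu_tanh are hypothetical and are not MXNet or PyTorch APIs.

```python
import math


def gelu_erf(x: float) -> float:
    """Erf-based GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Widely used tanh approximation of GELU (assumed here, not taken from MXNet)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Scan [-4, 4] in steps of 0.001 and record the largest absolute gap.
xs = [i / 1000.0 for i in range(-4000, 4001)]
max_abs_diff = max(abs(gelu_erf(x) - gelu_tanh(x)) for x in xs)
print(f"max |gelu_erf - gelu_tanh| on [-4, 4]: {max_abs_diff:.2e}")
```

If the printed gap is of the same order as the "Max absolute difference" in the error above, the hypothesis is plausible, but only inspecting the CPU kernel would confirm it.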