This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[Activation] GELU precision mismatch between MXNet and PyTorch in the CPU version #18826

@sxjscience

Description

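On CPU, mx.npx.leaky_relu(x, act_type='gelu') returns values that differ from PyTorch's torch.nn.functional.gelu by more than rtol=1e-4, atol=1e-4. The same GELU written out explicitly with mx.npx.erf matches PyTorch within that tolerance, which points at the CPU GELU operator rather than at the input data.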

The minimal reproducible example:

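import math

import mxnet as mx
from numpy.testing import assert_allclose

mx.npx.set_np()
a = mx.np.random.normal(0, 1, (10000,))  # MXNet input on the CPU (default context)
b = mx.npx.leaky_relu(a, act_type='gelu')  # built-in GELU operator
c = a * 0.5 * (1.0 + mx.npx.erf(a / math.sqrt(2.0)))  # GELU written out with erf

import torch
a_torch = torch.from_numpy(a.asnumpy()).cuda()  # PyTorch reference runs on the GPU
b_torch = torch.nn.functional.gelu(a_torch)
assert_allclose(b_torch.cpu().numpy(), c.asnumpy(), 1E-4, 1E-4)  # passes
assert_allclose(b_torch.cpu().numpy(), b.asnumpy(), 1E-4, 1E-4)  # fails on CPU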

The same check passes when the MXNet arrays are created on the GPU:

import math

import mxnet as mx
from numpy.testing import assert_allclose

mx.npx.set_np()
a = mx.np.random.normal(0, 1, (10000,), ctx=mx.gpu())  # MXNet input on the GPU
b = mx.npx.leaky_relu(a, act_type='gelu')  # built-in GELU operator
c = a * 0.5 * (1.0 + mx.npx.erf(a / math.sqrt(2.0)))  # GELU written out with erf

import torch
a_torch = torch.from_numpy(a.asnumpy()).cuda()
b_torch = torch.nn.functional.gelu(a_torch)
assert_allclose(b_torch.cpu().numpy(), c.asnumpy(), 1E-4, 1E-4)  # passes
assert_allclose(b_torch.cpu().numpy(), b.asnumpy(), 1E-4, 1E-4)  # passes

@pengzhao-intel @ciyongch

Error:

<ipython-input-48-6f3377797f65> in <module>
      9 b_torch = torch.nn.functional.gelu(a_torch)
     10 assert_allclose(b_torch.cpu().numpy(), c.asnumpy(), 1E-4, 1E-4)
---> 11 assert_allclose(b_torch.cpu().numpy(), b.asnumpy(), 1E-4, 1E-4)

~/.local/lib/python3.6/site-packages/numpy/testing/_private/utils.py in assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
   1526     header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
   1527     assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
-> 1528                          verbose=verbose, header=header, equal_nan=equal_nan)
   1529 
   1530 

~/.local/lib/python3.6/site-packages/numpy/testing/_private/utils.py in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf)
    838                                 verbose=verbose, header=header,
    839                                 names=('x', 'y'), precision=precision)
--> 840             raise AssertionError(msg)
    841     except ValueError:
    842         import traceback

AssertionError: 
Not equal to tolerance rtol=0.0001, atol=0.0001

Mismatched elements: 2258 / 10000 (22.6%)
Max absolute difference: 0.0004735
Max relative difference: 0.8255573
 x: array([ 0.684651,  0.508604, -0.165598, ...,  1.706593,  0.288036,
        1.006167], dtype=float32)
 y: array([ 0.68455 ,  0.508554, -0.165716, ...,  1.706508,  0.288026,
        1.005966], dtype=float32)
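
One way to narrow this down, offered only as a hypothesis: the CPU path might be computing the tanh approximation of GELU instead of the erf form. Nothing in this report confirms that. The sketch below uses only the Python standard library to measure how far the two formulas drift apart, so the result can be compared with the maximum absolute difference reported above. `gelu_erf` mirrors the expression used for `c`; `gelu_tanh` and `max_abs_gap` are illustrative helpers, not MXNet or PyTorch APIs.

```python
import math


def gelu_erf(x):
    """Exact GELU, the same formula as `c` in the report."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """Tanh approximation of GELU (assumed candidate for the CPU path)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def max_abs_gap(lo=-5.0, hi=5.0, steps=10001):
    """Largest |gelu_erf - gelu_tanh| on an evenly spaced grid (hypothetical helper)."""
    xs = [lo + (hi - lo) * i / (steps - 1) for i in range(steps)]
    return max(abs(gelu_erf(x) - gelu_tanh(x)) for x in xs)


if __name__ == "__main__":
    print(f"max |erf - tanh| on [-5, 5]: {max_abs_gap():.7f}")
```

If the printed gap is of the same order as the 0.0004735 in the assertion message, that would be consistent with the tanh hypothesis but not proof of it; checking which GELU algorithm the CPU backend actually selects would still be needed.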
