From f20853db8f0084bfb53193a4287acbc39c52df81 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 26 Feb 2024 09:21:58 -0500 Subject: [PATCH 1/5] add: fast gelu approx --- python/mlx/nn/layers/activations.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index dfd435cfd8..1df8b4ba11 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -176,7 +176,7 @@ def gelu_approx(x): @partial(mx.compile, shapeless=True) def gelu_fast_approx(x): - r"""A fast approximation to Gaussian Error Linear Unit. + r"""A fast approximation to Gaussian Error Linear Unit. See :func:`gelu` for the exact computation. @@ -185,11 +185,13 @@ def gelu_fast_approx(x): .. math:: - x = x \sigma\left(1.773 x\right) + x = x \sigma\left(1.702 x\right) where :math:`\sigma(\cdot)` is the logistic sigmoid. + + Reference: https://github.com/hendrycks/GELUs & https://arxiv.org/pdf/1606.08415.pdf """ - return x * mx.sigmoid(1.773 * x) + return x * mx.sigmoid(1.702 * x) def glu(x: mx.array, axis: int = -1) -> mx.array: From 5d9456826a980740f2dbdc4ee4c475746202005f Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 26 Feb 2024 09:22:58 -0500 Subject: [PATCH 2/5] fix docs --- python/mlx/nn/layers/activations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index 1df8b4ba11..53f98769b0 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -176,7 +176,7 @@ def gelu_approx(x): @partial(mx.compile, shapeless=True) def gelu_fast_approx(x): - r"""A fast approximation to Gaussian Error Linear Unit. + r"""A fast approximation to Gaussian Error Linear Unit. See :func:`gelu` for the exact computation. @@ -189,7 +189,9 @@ def gelu_fast_approx(x): where :math:`\sigma(\cdot)` is the logistic sigmoid. - Reference: https://github.com/hendrycks/GELUs & https://arxiv.org/pdf/1606.08415.pdf + References: + - https://github.com/hendrycks/GELUs + - https://arxiv.org/pdf/1606.08415.pdf """ return x * mx.sigmoid(1.702 * x) From 66fa314aaf104aad3e8ffbd41f3d71d08216b727 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 26 Feb 2024 09:26:59 -0500 Subject: [PATCH 3/5] Update gelu_fast_approx function documentation --- python/mlx/nn/layers/activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index 53f98769b0..662b7bd0a2 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -189,7 +189,7 @@ def gelu_fast_approx(x): where :math:`\sigma(\cdot)` is the logistic sigmoid. - References: + References: - https://github.com/hendrycks/GELUs - https://arxiv.org/pdf/1606.08415.pdf """ From 0252bc64f1fef3bea40f18c4862be8343cf538b7 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 26 Feb 2024 10:36:37 -0500 Subject: [PATCH 4/5] Update python/mlx/nn/layers/activations.py Co-authored-by: Awni Hannun --- python/mlx/nn/layers/activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index 662b7bd0a2..336b4292bd 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -191,7 +191,7 @@ def gelu_fast_approx(x): References: - https://github.com/hendrycks/GELUs - - https://arxiv.org/pdf/1606.08415.pdf + - https://arxiv.org/abs/1606.08415 """ return x * mx.sigmoid(1.702 * x) From 8e277524a4a2851b2fb6a6ec131e904245d6645d Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Mon, 26 Feb 2024 19:38:50 -0500 Subject: [PATCH 5/5] fix: test gelu --- python/tests/test_nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 2c83461798..99154d3f6c 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -665,7 +665,7 @@ def test_gelu(self): y_hat1 = nn.gelu_approx(x) y_hat2 = nn.gelu_fast_approx(x) self.assertLess(mx.abs(y - y_hat1).max(), 0.0003) - self.assertLess(mx.abs(y - y_hat2).max(), 0.02) + self.assertLess(mx.abs(y - y_hat2).max(), 0.025) def test_sin_pe(self): m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)