Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import warnings
import tensorflow
from functools import wraps
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -64,7 +65,12 @@ def gelu(x: tf.Tensor) -> tf.Tensor:
Original paper
https://arxiv.org/abs/1606.08415
"""
return op_module.gelu(x)
def gelu_wrapper(x):
try:
return tensorflow.nn.gelu(x, approximate=True)
except AttributeError:
return op_module.gelu(x)
return (lambda x: gelu_wrapper(x))(x)


# TODO this is not a good way to do things. This is some global variable to which
Expand Down
17 changes: 10 additions & 7 deletions source/op/_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
"""
First-order derivatives and second-order derivatives for gelu function.
"""

import tensorflow
from tensorflow.python.framework import ops
from deepmd.env import op_module

@ops.RegisterGradient("Gelu")
def _gelu_cc (op, dy) :
return op_module.gelu_grad(dy, op.inputs[0])
try:
gelu = tensorflow.nn.gelu
except AttributeError:
@ops.RegisterGradient("Gelu")
def _gelu_cc (op, dy) :
return op_module.gelu_grad(dy, op.inputs[0])

@ops.RegisterGradient("GeluGrad")
def _gelu_grad_cc (op, dy) :
return [op_module.gelu_grad(dy, op.inputs[1]), op_module.gelu_grad_grad(dy, op.inputs[0], op.inputs[1])]
@ops.RegisterGradient("GeluGrad")
def _gelu_grad_cc (op, dy) :
return [op_module.gelu_grad(dy, op.inputs[1]), op_module.gelu_grad_grad(dy, op.inputs[0], op.inputs[1])]